dynamo_parsers/tool_calling/json/base_json_parser.rs

1// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2// SPDX-License-Identifier: Apache-2.0
3
4use std::collections::HashMap;
5
6use regex::RegexBuilder;
7use serde_json::Value;
8use uuid::Uuid;
9
10use super::config::JsonParserConfig;
11use super::response::{CalledFunction, ToolCallResponse, ToolCallType};
12
13// Same as CalledFunction with named parameters
14#[derive(Clone, Debug, serde::Serialize, serde::Deserialize)]
15pub struct CalledFunctionParameters {
16    pub name: String,
17    pub parameters: HashMap<String, Value>,
18}
19
20// Same as CalledFunction with named parameters
21#[derive(Clone, Debug, serde::Serialize, serde::Deserialize)]
22pub struct CalledFunctionArguments {
23    pub name: String,
24    pub arguments: HashMap<String, Value>,
25}
26
27// Extract the contents between start and end tokens using regex parsing.
28// Returns a JSON array string if there are multiple matches, otherwise returns the last match directly.
29fn extract_tool_call_content(input: &str, start_token: &str, end_token: &str) -> Option<String> {
30    let escaped_start = regex::escape(start_token);
31    let escaped_end = regex::escape(end_token);
32    let pattern = format!(r"{}(.*?){}", escaped_start, escaped_end);
33
34    match RegexBuilder::new(&pattern)
35        .dot_matches_new_line(true)
36        .build()
37    {
38        Ok(regex) => {
39            // Get all matches and take the last one for now. TODO: Handle multiple tool calls
40            let matches: Vec<_> = regex
41                .captures_iter(input)
42                .filter_map(|captures| captures.get(1))
43                .map(|m| m.as_str().trim().to_string())
44                .collect();
45            if !matches.is_empty() {
46                // If only one match, return it directly, otherwise return as a JSON array string
47                if matches.len() == 1 {
48                    // Return the last match directly
49                    return Some(matches.last().unwrap().clone());
50                } else {
51                    // Join the matches into a JSON array string
52                    return Some(format!("[{}]", matches.join(",")));
53                }
54            }
55            None
56        }
57        Err(_) => None,
58    }
59}
60
61// Special case for <|python_tag|> . Regex pattern does not work well with it as it has no end token
62// Handles single tool and multiple tool call cases for single start_token like <|python_tag|>
63fn handle_single_token_tool_calls(input: &str, start_token: &str) -> Option<String> {
64    // Return the input if it doesn't contain the start token
65    if !input.contains(start_token) {
66        return None;
67    }
68
69    // Split on the start token and keep only JSON-looking segments
70    let mut items: Vec<String> = Vec::new();
71    for seg in input.split(start_token) {
72        let s = seg.trim();
73        if s.is_empty() {
74            continue;
75        }
76        // Only consider segments that start like JSON (objects or arrays)
77        if s.starts_with('{') {
78            // Trim trailing non-JSON by cutting at the last closing brace
79            if let Some(pos) = s.rfind('}') {
80                let candidate = &s[..=pos].trim();
81                // Keep only valid JSON candidates
82                if serde_json::from_str::<serde_json::Value>(candidate).is_ok() {
83                    items.push(candidate.to_string());
84                }
85            }
86        } else if s.starts_with('[') {
87            // Handle array format (like phi4: functools[{...}])
88            if let Some(pos) = s.rfind(']') {
89                let candidate = &s[..=pos].trim();
90                // Keep only valid JSON arrays
91                if serde_json::from_str::<serde_json::Value>(candidate).is_ok() {
92                    // For arrays, we need to extract the individual objects
93                    if let Ok(serde_json::Value::Array(arr)) =
94                        serde_json::from_str::<serde_json::Value>(candidate)
95                    {
96                        for item in arr {
97                            if let Ok(item_str) = serde_json::to_string(&item) {
98                                items.push(item_str);
99                            }
100                        }
101                    }
102                }
103            }
104        }
105    }
106    if items.is_empty() {
107        // If we found the start token but no valid JSON after it, return empty string
108        // to avoid leaking the invalid content (important for phi4 and similar models)
109        return Some(String::new());
110    }
111    Some(format!("[{}]", items.join(",")))
112}
113
// Returns the trimmed text that precedes the first occurrence of `start_token`, or an empty
// string when `start_token` does not occur in `input`.
fn try_parse_normal_text(input: &str, start_token: &str) -> String {
    match input.split_once(start_token) {
        Some((before_token, _)) => before_token.trim().to_owned(),
        None => String::new(),
    }
}
123
124/// Attempts to parse a tool call from a raw LLM message string into a unified [`ToolCallResponse`] format.
125///
126/// This is a flexible helper that handles a variety of potential formats emitted by LLMs for function/tool calls,
127/// including wrapped payloads (`<TOOLCALL>[...]</TOOLCALL>`, `<|python_tag|>...`) and JSON representations
128/// with either `parameters` or `arguments` fields.
129///
130/// # Supported Formats
131///
132/// The input `message` may be one of:
133///
134/// - `<TOOLCALL>[{ "name": ..., "parameters": { ... } }]</TOOLCALL>`
135/// - `<|python_tag|>{ "name": ..., "arguments": { ... } }`
136/// - Raw JSON of:
137///     - `CalledFunctionParameters`: `{ "name": ..., "parameters": { ... } }`
138///     - `CalledFunctionArguments`: `{ "name": ..., "arguments": { ... } }`
139///     - Or a list of either of those types: `[ { "name": ..., "arguments": { ... } }, ... ]`
140///
141/// # Return
142///
143/// - `Ok(Some(ToolCallResponse))` if parsing succeeds
144/// - `Ok(None)` if input format is unrecognized or invalid JSON
145/// - `Err(...)` if JSON is valid but deserialization or argument re-serialization fails
146///
147/// # Note on List Handling
148///
149/// When the input contains a list of tool calls (either with `parameters` or `arguments`),
150/// only the **last item** in the list is returned. This design choice assumes that the
151/// most recent tool call in a list is the one to execute.
152///
153/// # Errors
154///
155/// Returns a `Result::Err` only if an inner `serde_json::to_string(...)` fails
156/// (e.g., if the arguments are not serializable).
157///
158/// # Examples
159///
160/// ```ignore
161/// let input = r#"<TOOLCALL>[{ "name": "search", "parameters": { "query": "rust" } }]</TOOLCALL>"#;
162/// let result = try_tool_call_parse_json(input)?;
163/// assert!(result.is_some());
164/// ```
165pub fn try_tool_call_parse_basic_json(
166    message: &str,
167    config: &JsonParserConfig,
168) -> anyhow::Result<(Vec<ToolCallResponse>, Option<String>)> {
169    // Log the config we are using
170    tracing::debug!("Using JSON parser config: {:?}", config);
171    let trimmed = message.trim();
172
173    // Early exit if no content
174    if trimmed.is_empty() {
175        return Ok((vec![], Some(String::new())));
176    }
177
178    let tool_call_start_tokens = &config.tool_call_start_tokens;
179    let tool_call_end_tokens = &config.tool_call_end_tokens;
180
181    // Early exit if no tokens configured
182    if tool_call_start_tokens.is_empty() {
183        return Ok((vec![], Some(trimmed.to_string())));
184    }
185
186    // Iterate over all start and end tokens and try to extract the content between them
187    // Assumption : One message will not contain different tags for tool calls. Iteration over tags is to support different tags by default for multiple models
188    let mut json = trimmed.to_string();
189    let mut normal_text = trimmed.to_string();
190    let mut found_start_token_with_no_valid_json = false;
191
192    // First, check if ANY start token exists in the input
193    let has_start_token = tool_call_start_tokens
194        .iter()
195        .any(|token| !token.is_empty() && normal_text.contains(token));
196
197    if !has_start_token {
198        // No start tokens found, try to extract JSON directly. Everything that starts with { or [ is considered a potential JSON.
199        if let Some(idx) = normal_text.find(['{', '[']) {
200            let extracted_normal = normal_text[..idx].trim().to_string();
201            let extracted_json = normal_text[idx..].trim().to_string();
202            if !extracted_json.is_empty() {
203                normal_text = extracted_normal;
204                json = extracted_json;
205            }
206        }
207    } else {
208        // Start tokens exist, use regex-based parsing
209        // Try all combinations of start and end tokens
210        'outer: for start_token in tool_call_start_tokens.iter() {
211            for end_token in tool_call_end_tokens.iter() {
212                let new_normal_text = try_parse_normal_text(&normal_text, start_token);
213
214                // Process based on token types
215                match (start_token.is_empty(), end_token.is_empty()) {
216                    (false, true) => {
217                        // Single token case
218                        let result = handle_single_token_tool_calls(&json, start_token);
219                        if let Some(content) = result {
220                            // Check if we found a start token but got empty JSON back
221                            // This indicates the token was found but no valid JSON followed
222                            if content.is_empty() {
223                                found_start_token_with_no_valid_json = true;
224                            }
225
226                            json = content;
227                            // For single token case, use the normal text we extracted earlier
228                            normal_text = new_normal_text;
229
230                            break 'outer; // Found content, exit early
231                        }
232                    }
233                    (false, false) => {
234                        // Start and end token case
235                        let result = extract_tool_call_content(&json, start_token, end_token);
236                        if let Some(content) = result {
237                            // Check if we found a start token but got empty JSON back
238                            // This indicates the token was found but no valid JSON followed
239                            if content.is_empty() {
240                                found_start_token_with_no_valid_json = true;
241                            }
242
243                            json = content;
244                            normal_text = new_normal_text;
245
246                            break 'outer; // Found content, exit early
247                        }
248                    }
249                    _ => {
250                        continue;
251                    }
252                }
253            }
254        }
255    }
256    // Convert json (String) to &str
257    let json = json.as_str();
258    // Anonymous function to attempt deserialization into a known representation
259    let parse = |name: String, args: HashMap<String, Value>| -> anyhow::Result<ToolCallResponse> {
260        // Preserve nested JSON strings intact; do not double-escape.
261        // serde_json::to_string on Value preserves required escapes only.
262        Ok(ToolCallResponse {
263            id: format!("call-{}", Uuid::new_v4()),
264            tp: ToolCallType::Function,
265            function: CalledFunction {
266                name,
267                arguments: serde_json::to_string(&args)?,
268            },
269        })
270    };
271
272    // CalledFunctionParameters: Single { name, parameters }
273    // Example:
274    // {
275    //   "name": "search_docs",
276    //   "parameters": {
277    //     "query": "how to use Rust",
278    //     "limit": 5
279    //   }
280    // }
281    if let Ok(single) = serde_json::from_str::<CalledFunctionParameters>(json) {
282        return Ok((
283            vec![parse(single.name, single.parameters)?],
284            Some(normal_text),
285        ));
286        //parse(single.name, single.parameters).map(Some);
287
288        // CalledFunctionArguments: Single { name, arguments }
289        // Example:
290        // {
291        //   "name": "summarize",
292        //   "arguments": {
293        //     "text": "Rust is a systems programming language.",
294        //     "length": "short"
295        //   }
296        // }
297    } else if let Ok(single) = serde_json::from_str::<CalledFunctionArguments>(json) {
298        return Ok((
299            vec![parse(single.name, single.arguments)?],
300            Some(normal_text),
301        ));
302
303    // Vec<CalledFunctionParameters> or Vec<CalledFunctionArguments>: Array of tool calls
304    // Example:
305    // [
306    //   { "name": "lookup_user", "parameters": { "user_id": "123" } },
307    //   { "name": "get_weather", "arguments": { "location": "SF", "units": "celsius" } }
308    // ]
309    // Parse as generic array to handle both formats and malformed entries gracefully
310    // Note: Always return once we parse a valid array, even if empty or with malformed entries
311    } else if let Ok(array) = serde_json::from_str::<Vec<serde_json::Value>>(json) {
312        let mut results = Vec::new();
313        for item in array {
314            // Try both CalledFunctionArguments and CalledFunctionParameters formats
315            if let Ok(func_args) = serde_json::from_value::<CalledFunctionArguments>(item.clone()) {
316                results.push(parse(func_args.name, func_args.arguments)?);
317            } else if let Ok(func_params) = serde_json::from_value::<CalledFunctionParameters>(item)
318            {
319                results.push(parse(func_params.name, func_params.parameters)?);
320            }
321            // Skip malformed entries silently
322        }
323        // Return with whatever results we have, even if empty (e.g., [] is a valid empty array)
324        return Ok((results, Some(normal_text)));
325    }
326
327    // If we found a start token but no valid JSON, return empty content
328    // to avoid leaking the token and invalid JSON content
329    if found_start_token_with_no_valid_json {
330        Ok((vec![], Some(String::new())))
331    } else {
332        Ok((vec![], Some(trimmed.to_string())))
333    }
334}
335
336pub fn detect_tool_call_start_basic_json(chunk: &str, config: &JsonParserConfig) -> bool {
337    let trimmed = chunk.trim();
338    if trimmed.is_empty() {
339        return false;
340    }
341
342    // Check if chunk contains any complete start token
343    let contains_complete_token = config
344        .tool_call_start_tokens
345        .iter()
346        .any(|token| !token.is_empty() && trimmed.contains(token));
347
348    if contains_complete_token {
349        return true;
350    }
351
352    // Check for partial start tokens (streaming scenario)
353    // This handles cases where start tokens are split across multiple chunks
354    let has_partial_token = config.tool_call_start_tokens.iter().any(|token| {
355        if token.is_empty() {
356            return false;
357        }
358        // Check if the chunk could be a prefix of this start token
359        // Handle Unicode character boundaries properly
360        for i in 1..=token.chars().count() {
361            if let Some(prefix) = token.chars().take(i).collect::<String>().get(..) {
362                let prefix_str = &prefix[..prefix.len()];
363                // Check for exact prefix match
364                if trimmed == prefix_str {
365                    return true;
366                }
367                // For longer prefixes (3+ chars), allow them anywhere in the input
368                // This allows "funny joke" to match "functools" via "fun"
369                // but prevents "<tool_call>" from matching "<TOOLCALL>" via single char "<"
370                if prefix_str.len() >= 3 && trimmed.contains(prefix_str) {
371                    return true;
372                }
373                // For shorter prefixes, only match if they're at the end (streaming scenario)
374                if prefix_str.len() < 3 && trimmed.ends_with(prefix_str) {
375                    return true;
376                }
377            }
378        }
379        false
380    });
381
382    has_partial_token || trimmed.contains('{') || trimmed.contains('[')
383}
384
#[cfg(test)]
mod detect_parser_tests {
    use super::*;

    /// Builds a config with exactly one start token and one end token.
    fn config_with(start_token: &str, end_token: &str) -> JsonParserConfig {
        JsonParserConfig {
            tool_call_start_tokens: vec![start_token.to_string()],
            tool_call_end_tokens: vec![end_token.to_string()],
            ..Default::default()
        }
    }

    /// `<tool_call>` / `</tool_call>` tags.
    fn hermes_config() -> JsonParserConfig {
        config_with("<tool_call>", "</tool_call>")
    }

    /// `functools` start token with no end token.
    fn phi4_config() -> JsonParserConfig {
        config_with("functools", "")
    }

    #[test]
    fn detect_tool_call_start_basic_json_chunk_with_tool_call_start_token_hermes() {
        let chunk =
            r#"<tool_call>{"name": "search", "parameters": { "query": "rust" } }</tool_call>"#;
        assert!(detect_tool_call_start_basic_json(chunk, &hermes_config()));
    }

    #[test]
    fn detect_tool_call_start_basic_json_chunk_without_tool_call_start_token() {
        let chunk = r#"{"name": "search", "parameters": { "query": "rust" } }"#;
        assert!(detect_tool_call_start_basic_json(chunk, &hermes_config()));
    }

    #[test]
    fn detect_tool_call_start_basic_json_chunk_without_tool_call_start_token_with_normal_text() {
        let chunk = r#"Here it is {"name": "#;
        assert!(detect_tool_call_start_basic_json(chunk, &hermes_config()));
    }

    #[test]
    fn detect_tool_call_start_basic_json_chunk_with_square_brackets() {
        // These kind of false positives are expected when calling this function for stream=True
        let chunk = r#"Here it is [{"name": "search","#;
        assert!(detect_tool_call_start_basic_json(chunk, &hermes_config()));
    }

    #[test]
    fn detect_tool_call_start_basic_json_chunk_false_positive() {
        // These kind of false positives are expected when calling this function for stream=True
        let chunk = r#"Here it is { Whats up"#;
        assert!(detect_tool_call_start_basic_json(chunk, &hermes_config()));
    }

    #[test]
    fn detect_tool_call_start_basic_json_chunk_with_tool_call_start_token_nemotron_deci() {
        let chunk =
            r#"<TOOLCALL>[{"name": "search", "parameters": { "query": "rust" } }]</TOOLCALL>"#;
        let config = config_with("<TOOLCALL>", "</TOOLCALL>");
        assert!(detect_tool_call_start_basic_json(chunk, &config));
    }

    #[test]
    fn detect_tool_call_start_basic_json_chunk_with_lllama3_json_token() {
        let chunk = r#"<|python_tag|>{ "name": }"#;
        let config = config_with("<|python_tag|>", "");
        assert!(detect_tool_call_start_basic_json(chunk, &config));
    }

    #[test]
    fn detect_tool_call_start_basic_json_chunk_mistral_token() {
        let chunk = r#"Hello Yo ! [TOOL_CALLS]{"name": "search", "#;
        let config = config_with("[TOOL_CALLS]", "");
        assert!(detect_tool_call_start_basic_json(chunk, &config));
    }

    #[test]
    fn detect_tool_call_start_basic_json_chunk_phi4_token() {
        let chunk = r#"functools{"name": "search", "#;
        assert!(detect_tool_call_start_basic_json(chunk, &phi4_config()));
    }

    #[test]
    fn detect_tool_call_start_basic_json_chunk_phi4_partial_token_fun() {
        // Test the streaming scenario where "fun" arrives first
        assert!(
            detect_tool_call_start_basic_json("fun", &phi4_config()),
            "Should detect 'fun' as potential start of 'functools'"
        );
    }

    #[test]
    fn detect_tool_call_start_basic_json_chunk_phi4_partial_token_func() {
        assert!(
            detect_tool_call_start_basic_json("func", &phi4_config()),
            "Should detect 'func' as potential start of 'functools'"
        );
    }

    #[test]
    fn detect_tool_call_start_basic_json_chunk_phi4_partial_token_f() {
        assert!(
            detect_tool_call_start_basic_json("f", &phi4_config()),
            "Should detect 'f' as potential start of 'functools'"
        );
    }

    #[test]
    fn detect_tool_call_start_basic_json_chunk_phi4_partial_with_prefix() {
        // Text that ends with a partial token (the more realistic streaming scenario)
        assert!(
            detect_tool_call_start_basic_json("Hello fun", &phi4_config()),
            "Should detect text ending with 'fun' as potential tool call start"
        );
    }

    #[test]
    fn detect_tool_call_start_basic_json_chunk_phi4_avoid_false_positive() {
        // "funny joke" still matches because "fun" is a 3-character prefix of "functools".
        // That is expected: in streaming, false positives are acceptable, missed real tool
        // calls are not.
        assert!(detect_tool_call_start_basic_json("funny joke", &phi4_config()));
    }

    #[test]
    fn detect_tool_call_start_basic_json_chunk_phi4_no_match() {
        assert!(
            !detect_tool_call_start_basic_json("hello world", &phi4_config()),
            "Should not detect unrelated text as tool call start"
        );
    }
}