agent_chain_core/utils/
json.rs

1//! Utilities for JSON parsing.
2//!
3//! Adapted from langchain_core/utils/json.py
4
5use regex::Regex;
6use serde_json::Value;
7
8/// Parse a JSON string that may be missing closing braces.
9///
10/// This function attempts to parse a JSON string that may be incomplete,
11/// such as from a streaming LLM response.
12///
13/// # Arguments
14///
15/// * `s` - The JSON string to parse.
16/// * `strict` - Whether to use strict parsing (disallow control characters in strings).
17///
18/// # Returns
19///
20/// The parsed JSON value, or an error if parsing fails.
21///
22/// # Example
23///
24/// ```
25/// use agent_chain_core::utils::json::parse_partial_json;
26///
27/// let result = parse_partial_json(r#"{"key": "value"}"#, false);
28/// assert!(result.is_ok());
29/// ```
30pub fn parse_partial_json(s: &str, strict: bool) -> Result<Value, JsonParseError> {
31    if let Ok(value) = serde_json::from_str(s) {
32        return Ok(value);
33    }
34
35    let mut new_chars = Vec::new();
36    let mut stack = Vec::new();
37    let mut is_inside_string = false;
38    let mut escaped = false;
39
40    for char in s.chars() {
41        let mut new_char = char.to_string();
42
43        if is_inside_string {
44            if char == '"' && !escaped {
45                is_inside_string = false;
46            } else if char == '\n' && !escaped {
47                new_char = "\\n".to_string();
48            } else if char == '\\' {
49                escaped = !escaped;
50            } else {
51                escaped = false;
52            }
53        } else if char == '"' {
54            is_inside_string = true;
55            escaped = false;
56        } else if char == '{' {
57            stack.push('}');
58        } else if char == '[' {
59            stack.push(']');
60        } else if (char == '}' || char == ']')
61            && let Some(expected) = stack.last()
62        {
63            if *expected == char {
64                stack.pop();
65            } else {
66                return Err(JsonParseError::MismatchedBracket);
67            }
68        }
69
70        new_chars.push(new_char);
71    }
72
73    if is_inside_string {
74        if escaped {
75            new_chars.pop();
76        }
77        new_chars.push("\"".to_string());
78    }
79
80    stack.reverse();
81
82    while !new_chars.is_empty() {
83        let mut attempt = new_chars.join("");
84        for closer in &stack {
85            attempt.push(*closer);
86        }
87
88        match serde_json::from_str::<Value>(&attempt) {
89            Ok(value) => {
90                if strict && contains_control_chars(&attempt) {
91                    return Err(JsonParseError::ControlCharacters);
92                }
93                return Ok(value);
94            }
95            Err(_) => {
96                new_chars.pop();
97            }
98        }
99    }
100
101    serde_json::from_str(s).map_err(|e| JsonParseError::ParseError(e.to_string()))
102}
103
/// Report whether `s` contains a control character other than line feed,
/// carriage return, or horizontal tab.
fn contains_control_chars(s: &str) -> bool {
    for c in s.chars() {
        let is_plain_whitespace = matches!(c, '\n' | '\r' | '\t');
        if c.is_control() && !is_plain_whitespace {
            return true;
        }
    }
    false
}
108
109/// Parse a JSON string from a Markdown string.
110///
111/// This function extracts JSON from a Markdown code block if present.
112///
113/// # Arguments
114///
115/// * `json_string` - The Markdown string.
116///
117/// # Returns
118///
119/// The parsed JSON value, or an error if parsing fails.
120///
121/// # Example
122///
123/// ```
124/// use agent_chain_core::utils::json::parse_json_markdown;
125///
126/// let result = parse_json_markdown(r#"```json
127/// {"key": "value"}
128/// ```"#);
129/// assert!(result.is_ok());
130/// ```
131pub fn parse_json_markdown(json_string: &str) -> Result<Value, JsonParseError> {
132    // Try to parse directly first
133    if let Ok(value) = parse_json_inner(json_string) {
134        return Ok(value);
135    }
136
137    // Try to find JSON string within triple backticks (with (?s) for DOTALL)
138    let re = Regex::new(r"(?s)```(?:json)?(.*)").expect("Invalid regex");
139
140    let json_str = if let Some(caps) = re.captures(json_string) {
141        caps.get(1).map_or(json_string, |m| m.as_str())
142    } else {
143        json_string
144    };
145
146    parse_json_inner(json_str)
147}
148
/// Characters trimmed from both ends of a candidate JSON string before it is
/// parsed: ASCII whitespace plus backticks (e.g. the remains of a Markdown
/// code fence).
const JSON_STRIP_CHARS: &[char] = &[' ', '\n', '\r', '\t', '`'];
150
151fn parse_json_inner(json_str: &str) -> Result<Value, JsonParseError> {
152    let json_str = json_str.trim_matches(JSON_STRIP_CHARS);
153
154    let json_str = custom_parser(json_str);
155
156    parse_partial_json(&json_str, false)
157}
158
159fn custom_parser(multiline_string: &str) -> String {
160    // Use (?s) flag to make . match newlines (DOTALL mode)
161    let re = Regex::new(r#"(?s)("action_input"\s*:\s*")(.*?)(")"#).expect("Invalid regex");
162    re.replace_all(multiline_string, |caps: &regex::Captures| {
163        let prefix = caps.get(1).map_or("", |m| m.as_str());
164        let value = caps.get(2).map_or("", |m| m.as_str());
165        let suffix = caps.get(3).map_or("", |m| m.as_str());
166
167        let value = value.replace('\n', "\\n");
168        let value = value.replace('\r', "\\r");
169        let value = value.replace('\t', "\\t");
170        // Escape unescaped quotes within the value
171        let value = escape_unescaped_quotes(&value);
172
173        format!("{}{}{}", prefix, value, suffix)
174    })
175    .to_string()
176}
177
/// Escape double quotes that are not already escaped.
///
/// A backslash and the character that follows it are copied through verbatim
/// as a pair, so `\"` stays `\"` and `\\"` becomes `\\\"`. A trailing lone
/// backslash is kept as-is.
fn escape_unescaped_quotes(s: &str) -> String {
    let mut result = String::with_capacity(s.len());
    let mut chars = s.chars();

    while let Some(c) = chars.next() {
        match c {
            '\\' => {
                result.push(c);
                // Consume the escaped character so it is never re-examined.
                if let Some(next) = chars.next() {
                    result.push(next);
                }
            }
            '"' => result.push_str("\\\""),
            _ => result.push(c),
        }
    }

    result
}
199
200/// Parse a JSON string and check that it contains the expected keys.
201///
202/// # Arguments
203///
204/// * `text` - The Markdown string.
205/// * `expected_keys` - The expected keys in the JSON object.
206///
207/// # Returns
208///
209/// The parsed JSON object, or an error if parsing fails or keys are missing.
210///
211/// # Example
212///
213/// ```
214/// use agent_chain_core::utils::json::parse_and_check_json_markdown;
215///
216/// let result = parse_and_check_json_markdown(r#"{"key": "value"}"#, &["key"]);
217/// assert!(result.is_ok());
218/// ```
219pub fn parse_and_check_json_markdown(
220    text: &str,
221    expected_keys: &[&str],
222) -> Result<Value, JsonParseError> {
223    let json_obj = parse_json_markdown(text)?;
224
225    let obj = json_obj
226        .as_object()
227        .ok_or_else(|| JsonParseError::NotAnObject(format!("{:?}", json_obj)))?;
228
229    for key in expected_keys {
230        if !obj.contains_key(*key) {
231            return Err(JsonParseError::MissingKey(key.to_string()));
232        }
233    }
234
235    Ok(json_obj)
236}
237
/// Error types for JSON parsing.
///
/// Every payload is a plain `String`, so the type supports full equality
/// (`Eq`) in addition to `PartialEq`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum JsonParseError {
    /// Failed to parse JSON. Carries the underlying parser's error message.
    ParseError(String),
    /// A closing `}` or `]` did not match the most recently opened bracket.
    MismatchedBracket,
    /// Control characters found in strict mode.
    ControlCharacters,
    /// Expected an object but got something else. Carries a rendering of the
    /// value that was found instead.
    NotAnObject(String),
    /// Missing expected key. Carries the name of the missing key.
    MissingKey(String),
}
252
253impl std::fmt::Display for JsonParseError {
254    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
255        match self {
256            JsonParseError::ParseError(msg) => write!(f, "Failed to parse JSON: {}", msg),
257            JsonParseError::MismatchedBracket => write!(f, "Mismatched bracket in JSON"),
258            JsonParseError::ControlCharacters => write!(f, "Control characters found in JSON"),
259            JsonParseError::NotAnObject(got) => {
260                write!(f, "Expected JSON object (dict), but got: {}", got)
261            }
262            JsonParseError::MissingKey(key) => {
263                write!(f, "Missing expected key: {}", key)
264            }
265        }
266    }
267}
268
// Empty impl: the trait's default methods are used as-is, so no underlying
// source error is exposed.
impl std::error::Error for JsonParseError {}
270
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    /// Well-formed input is returned as-is.
    #[test]
    fn test_parse_partial_json_complete() {
        let parsed = parse_partial_json("{\"key\": \"value\"}", false);
        assert_eq!(parsed, Ok(json!({"key": "value"})));
    }

    /// A missing closing brace is supplied by the parser.
    #[test]
    fn test_parse_partial_json_incomplete() {
        let parsed = parse_partial_json("{\"key\": \"value\"", false);
        assert_eq!(parsed, Ok(json!({"key": "value"})));
    }

    /// A missing closing bracket is supplied by the parser.
    #[test]
    fn test_parse_partial_json_array() {
        let parsed = parse_partial_json("[1, 2, 3", false);
        assert_eq!(parsed, Ok(json!([1, 2, 3])));
    }

    /// JSON wrapped in a ```json fence is extracted and parsed.
    #[test]
    fn test_parse_json_markdown() {
        let markdown = "```json\n{\"key\": \"value\"}\n```";
        let parsed = parse_json_markdown(markdown);
        assert_eq!(parsed, Ok(json!({"key": "value"})));
    }

    /// Input without a fence is parsed directly.
    #[test]
    fn test_parse_json_markdown_no_fence() {
        let parsed = parse_json_markdown("{\"key\": \"value\"}");
        assert_eq!(parsed, Ok(json!({"key": "value"})));
    }

    /// All expected keys present: the call succeeds.
    #[test]
    fn test_parse_and_check_json_markdown() {
        parse_and_check_json_markdown("{\"key\": \"value\"}", &["key"])
            .expect("object containing the expected key should be accepted");
    }

    /// An absent expected key is reported as `MissingKey`.
    #[test]
    fn test_parse_and_check_json_markdown_missing_key() {
        match parse_and_check_json_markdown("{\"key\": \"value\"}", &["missing"]) {
            Err(JsonParseError::MissingKey(_)) => {}
            other => panic!("expected MissingKey error, got {:?}", other),
        }
    }

    /// A raw newline inside an `action_input` value is escaped.
    #[test]
    fn test_custom_parser() {
        let input = "{\"action_input\": \"line1\nline2\"}";
        let normalised = custom_parser(input);
        assert!(normalised.contains("\\n"));
    }
}