Skip to main content

atomcode_core/turn/
json_repair.rs

//! JSON repair utilities for malformed LLM tool-call output.
//!
//! LLMs frequently produce JSON with issues such as trailing commas, single quotes,
//! unquoted keys, invalid backslash escapes, and markdown code fences.
//! These functions attempt to repair such output before falling back to
//! last-resort key-value extraction.
7
8/// Normalize tool-call arguments into valid JSON before execution.
9///
10/// Runs the repair chain: direct parse → repair_json → tool-specific extractor →
11/// generic key-value extraction. Returns the original string unchanged if all
12/// strategies fail (caller can then surface a parse error to the model).
13///
14/// `tool_name` selects a specialized extractor when available (e.g. `edit_file`
15/// which may contain unescaped source code in `old_string`/`new_string`).
16pub fn repair_tool_args(tool_name: &str, args: &str) -> String {
17    // Fast path: already valid JSON.
18    if serde_json::from_str::<serde_json::Value>(args).is_ok() {
19        return args.to_string();
20    }
21    // Generic JSON repair (trailing commas, unquoted keys, fence strip, etc.).
22    let repaired = repair_json(args);
23    if serde_json::from_str::<serde_json::Value>(&repaired).is_ok() {
24        return repaired;
25    }
26    // Specialized: edit_file often ships source code with unescaped quotes/newlines.
27    if tool_name == "edit_file" {
28        if let Some(v) = extract_edit_file_args(args) {
29            if let Ok(s) = serde_json::to_string(&v) {
30                return s;
31            }
32        }
33    }
34    // Last resort: key-value field extraction. Only return this if it actually
35    // recovered something — an empty object is no better than the original garbage.
36    let extracted = extract_json_fields(args);
37    if let Some(obj) = extracted.as_object() {
38        if !obj.is_empty() {
39            if let Ok(s) = serde_json::to_string(&extracted) {
40                return s;
41            }
42        }
43    }
44    args.to_string()
45}
46
/// Attempt to repair common JSON issues from LLM output.
///
/// Repairs are applied in this order:
/// 1. Invalid backslash escapes are doubled (`\.` becomes `\\.`); `\u` counts
///    as valid only when followed by four hex digits.
/// 2. Surrounding whitespace and markdown code fences are stripped.
/// 3. Single quotes become double quotes, but only when the text contains no
///    double quote at all (so apostrophes inside real strings are untouched).
/// 4. A missing comma between a string value and a following `"key":` is inserted.
/// 5. Unquoted keys after `{` or `,` are quoted.
/// 6. Trailing commas before `}` or `]` are removed, with or without
///    whitespace in between.
/// 7. Text that does not start with `{` or `[` is wrapped in `{ ... }`.
/// 8. Unclosed `{` / `[` are closed in nesting order.
///
/// Steps 4-6 and 8 only act on characters outside string literals, so
/// commas, colons and braces inside values are never rewritten.
///
/// The result is a best-effort repair and is not guaranteed to be valid
/// JSON; callers must still parse it.
pub fn repair_json(s: &str) -> String {
    let escaped = double_invalid_escapes(s);
    let mut result = strip_code_fences(&escaped).to_string();

    // Only safe when no double quote exists: otherwise a `'` is most likely an
    // apostrophe inside a properly quoted string.
    if !result.contains('"') && result.contains('\'') {
        result = result.replace('\'', "\"");
    }

    let result = insert_missing_commas(&result);
    let result = quote_bare_keys(&result);
    let mut result = drop_trailing_commas(&result);

    // A bare `"key": value` list is treated as the inside of an object.
    if !result.starts_with('{') && !result.starts_with('[') {
        result = format!("{{{}}}", result);
    }

    close_open_containers(&mut result);
    result
}

/// Doubles the backslash of every escape sequence JSON does not define.
fn double_invalid_escapes(text: &str) -> String {
    // `\u` is handled separately because it needs four hex digits after it.
    const SIMPLE_ESCAPES: [char; 8] = ['\\', '"', '/', 'n', 'r', 't', 'b', 'f'];

    let chars: Vec<char> = text.chars().collect();
    let mut out = String::with_capacity(text.len() + 20);
    let mut i = 0;
    while i < chars.len() {
        let current = chars[i];
        if current != '\\' || i + 1 >= chars.len() {
            out.push(current);
            i += 1;
            continue;
        }
        let next = chars[i + 1];
        let is_valid = SIMPLE_ESCAPES.contains(&next)
            || (next == 'u'
                && chars
                    .get(i + 2..i + 6)
                    .map_or(false, |hex| hex.iter().all(|h| h.is_ascii_hexdigit())));
        if !is_valid {
            // `\.` -> `\\.`: the parser then sees a literal backslash.
            out.push('\\');
        }
        out.push('\\');
        out.push(next);
        i += 2;
    }
    out
}

/// Trims whitespace and removes a leading ```` ```json ```` / ```` ``` ```` and a trailing ```` ``` ````.
fn strip_code_fences(text: &str) -> &str {
    let mut body = text.trim();
    if let Some(rest) = body.strip_prefix("```json") {
        body = rest;
    }
    if let Some(rest) = body.strip_prefix("```") {
        body = rest;
    }
    if let Some(rest) = body.strip_suffix("```") {
        body = rest;
    }
    body.trim()
}

/// Returns the inclusive `(opening quote, closing quote)` index pairs of all
/// string literals. An unterminated final string extends to the last char.
fn string_spans(chars: &[char]) -> Vec<(usize, usize)> {
    let mut spans = Vec::new();
    let mut i = 0;
    while i < chars.len() {
        if chars[i] != '"' {
            i += 1;
            continue;
        }
        let open = i;
        i += 1;
        while i < chars.len() && chars[i] != '"' {
            if chars[i] == '\\' {
                // Skip the escaped character so `\"` does not end the string.
                i += 1;
            }
            i += 1;
        }
        let close = i.min(chars.len() - 1);
        spans.push((open, close));
        i = close + 1;
    }
    spans
}

/// Returns, per character, whether it lies outside every string literal.
fn structural_mask(chars: &[char]) -> Vec<bool> {
    let mut mask = vec![true; chars.len()];
    for (open, close) in string_spans(chars) {
        mask[open..=close].fill(false);
    }
    mask
}

/// Inserts `,` after a string that is followed only by whitespace and then a `"key":`.
fn insert_missing_commas(text: &str) -> String {
    let chars: Vec<char> = text.chars().collect();
    let spans = string_spans(&chars);

    // Indices of closing quotes that need a comma right after them (ascending).
    let mut comma_after = Vec::new();
    for pair in spans.windows(2) {
        let (_, prev_close) = pair[0];
        let (next_open, next_close) = pair[1];
        let gap = &chars[prev_close + 1..next_open];
        if gap.is_empty() || !gap.iter().all(|c| c.is_whitespace()) {
            continue;
        }
        // The second string is a key only if a `:` follows it.
        let after_key = chars[next_close + 1..].iter().find(|c| !c.is_whitespace());
        if after_key == Some(&':') {
            comma_after.push(prev_close);
        }
    }

    let mut out = String::with_capacity(text.len() + comma_after.len());
    let mut pending = comma_after.into_iter().peekable();
    for (idx, c) in chars.iter().enumerate() {
        out.push(*c);
        if pending.peek() == Some(&idx) {
            out.push(',');
            pending.next();
        }
    }
    out
}

/// Quotes identifier-like keys that follow a structural `{` or `,`.
fn quote_bare_keys(text: &str) -> String {
    let chars: Vec<char> = text.chars().collect();
    let structural = structural_mask(&chars);
    let mut out = String::with_capacity(text.len() + 20);
    let mut i = 0;
    while i < chars.len() {
        let c = chars[i];
        let opens_key_slot = structural[i] && matches!(c, '{' | ',');
        out.push(c);
        i += 1;
        if !opens_key_slot {
            continue;
        }
        while i < chars.len() && chars[i].is_whitespace() {
            out.push(chars[i]);
            i += 1;
        }
        let start = i;
        while i < chars.len() && (chars[i].is_alphanumeric() || chars[i] == '_') {
            i += 1;
        }
        if start == i {
            continue;
        }
        let word: String = chars[start..i].iter().collect();
        // Only a word followed by `:` is a key; `true`, numbers, etc. are copied as-is.
        let is_key = chars[i..].iter().find(|n| !n.is_whitespace()) == Some(&':');
        if is_key {
            out.push('"');
            out.push_str(&word);
            out.push('"');
        } else {
            out.push_str(&word);
        }
    }
    out
}

/// Removes structural commas whose next significant character is `}` or `]`.
fn drop_trailing_commas(text: &str) -> String {
    let chars: Vec<char> = text.chars().collect();
    let structural = structural_mask(&chars);
    let mut out = String::with_capacity(text.len());
    for (i, &c) in chars.iter().enumerate() {
        if c == ',' && structural[i] {
            // Look past whitespace and further commas so `,,}` and `,\n}` both collapse.
            let next = chars[i + 1..]
                .iter()
                .copied()
                .find(|n| !(n.is_whitespace() || *n == ','));
            if matches!(next, Some('}') | Some(']')) {
                continue;
            }
        }
        out.push(c);
    }
    out
}

/// Appends the closers for every `{` / `[` still open at the end of `text`.
fn close_open_containers(text: &mut String) {
    let chars: Vec<char> = text.chars().collect();
    let structural = structural_mask(&chars);
    let mut closers: Vec<char> = Vec::new();
    for (i, &c) in chars.iter().enumerate() {
        if !structural[i] {
            continue;
        }
        match c {
            '{' => closers.push('}'),
            '[' => closers.push(']'),
            '}' | ']' => {
                // Stray or mismatched closers are ignored rather than guessed at.
                if closers.last() == Some(&c) {
                    closers.pop();
                }
            }
            _ => {}
        }
    }
    text.extend(closers.into_iter().rev());
}
224
225/// Last-resort: extract ALL key-value pairs from malformed JSON by string matching.
226/// Tool-agnostic — no hardcoded field lists. Finds any `"key": "value"` or `key: value` pattern.
227pub fn extract_json_fields(s: &str) -> serde_json::Value {
228    let mut map = serde_json::Map::new();
229    let chars: Vec<char> = s.chars().collect();
230    let len = chars.len();
231    let mut i = 0;
232
233    while i < len {
234        // Find a key: either "key" or bare_key followed by :
235        let key = if chars[i] == '"' {
236            // Quoted key
237            let start = i + 1;
238            i = start;
239            while i < len && chars[i] != '"' {
240                i += 1;
241            }
242            if i >= len {
243                break;
244            }
245            let k: String = chars[start..i].iter().collect();
246            i += 1; // skip closing "
247            k
248        } else if chars[i].is_alphabetic() || chars[i] == '_' {
249            // Bare key
250            let start = i;
251            while i < len && (chars[i].is_alphanumeric() || chars[i] == '_') {
252                i += 1;
253            }
254            chars[start..i].iter().collect()
255        } else {
256            i += 1;
257            continue;
258        };
259
260        // Skip whitespace, expect :
261        while i < len && chars[i].is_whitespace() {
262            i += 1;
263        }
264        if i >= len || chars[i] != ':' {
265            continue;
266        }
267        i += 1; // skip :
268        while i < len && chars[i].is_whitespace() {
269            i += 1;
270        }
271        if i >= len {
272            break;
273        }
274
275        // Read value
276        if chars[i] == '"' {
277            // String value — extract and unescape JSON escape sequences
278            let start = i + 1;
279            i = start;
280            while i < len && chars[i] != '"' {
281                if chars[i] == '\\' {
282                    i += 1;
283                }
284                i += 1;
285            }
286            let raw: String = chars[start..i.min(len)].iter().collect();
287            // Unescape JSON sequences: \n → newline, \t → tab, \" → quote, \\ → backslash
288            let val = raw
289                .replace("\\n", "\n")
290                .replace("\\t", "\t")
291                .replace("\\\"", "\"")
292                .replace("\\\\", "\\");
293            map.insert(key, serde_json::json!(val));
294            if i < len {
295                i += 1;
296            }
297        } else if chars[i] == 't' || chars[i] == 'f' {
298            // Boolean
299            let start = i;
300            while i < len && chars[i].is_alphabetic() {
301                i += 1;
302            }
303            let word: String = chars[start..i].iter().collect();
304            match word.as_str() {
305                "true" => {
306                    map.insert(key, serde_json::json!(true));
307                }
308                "false" => {
309                    map.insert(key, serde_json::json!(false));
310                }
311                _ => {
312                    map.insert(key, serde_json::json!(word));
313                }
314            }
315        } else if chars[i].is_ascii_digit() || chars[i] == '-' {
316            // Number
317            let start = i;
318            while i < len && (chars[i].is_ascii_digit() || chars[i] == '.' || chars[i] == '-') {
319                i += 1;
320            }
321            let num_str: String = chars[start..i].iter().collect();
322            if let Ok(n) = num_str.parse::<i64>() {
323                map.insert(key, serde_json::json!(n));
324            } else if let Ok(f) = num_str.parse::<f64>() {
325                map.insert(key, serde_json::json!(f));
326            }
327        } else {
328            // Unquoted string value — read until , } ]
329            let start = i;
330            while i < len && !matches!(chars[i], ',' | '}' | ']' | '\n') {
331                i += 1;
332            }
333            let val: String = chars[start..i]
334                .iter()
335                .collect::<String>()
336                .trim()
337                .to_string();
338            if !val.is_empty() {
339                map.insert(key, serde_json::json!(val));
340            }
341        }
342    }
343
344    serde_json::Value::Object(map)
345}
346
347/// Specialized parser for edit_file arguments when JSON parsing fails.
348/// Models often generate old_string/new_string with unescaped quotes/newlines.
349/// This parser uses the known field order to extract content by position.
350pub fn extract_edit_file_args(raw: &str) -> Option<serde_json::Value> {
351    let fp_marker = raw.find("\"file_path\"")?;
352    let old_marker = raw.find("\"old_string\"")?;
353    let new_marker = raw.find("\"new_string\"")?;
354    if old_marker <= fp_marker || new_marker <= old_marker {
355        return None;
356    }
357
358    // Extract file_path (simple quoted string before old_string)
359    let fp_region = &raw[fp_marker + 11..old_marker];
360    let fp_colon = fp_region.find(':')?;
361    let fp_val = fp_region[fp_colon + 1..]
362        .trim()
363        .trim_matches(|c| c == '"' || c == ',')
364        .trim();
365    if fp_val.is_empty() {
366        return None;
367    }
368    let file_path = fp_val.to_string();
369
370    // Extract old_string: everything between "old_string": " and ", "new_string"
371    let old_colon = raw[old_marker..].find(':')?;
372    let old_start = old_marker + old_colon + 1;
373    let old_raw = &raw[old_start..new_marker];
374    let old_string = unescape_field_value(old_raw);
375
376    // Extract new_string: everything after "new_string": " to the end
377    let new_colon = raw[new_marker..].find(':')?;
378    let new_start = new_marker + new_colon + 1;
379    let new_raw = &raw[new_start..];
380    let new_string = unescape_field_value_end(new_raw);
381
382    if old_string.is_empty() && new_string.is_empty() {
383        return None;
384    }
385
386    let replace_all = raw.contains("\"replace_all\"")
387        && raw.rfind("true").map_or(false, |t| {
388            raw.rfind("\"replace_all\"").map_or(false, |r| t > r)
389        });
390
391    Some(serde_json::json!({
392        "file_path": file_path,
393        "old_string": old_string,
394        "new_string": new_string,
395        "replace_all": replace_all,
396    }))
397}
398
/// Decodes `\n`, `\t`, `\"` and `\\` in a single left-to-right pass.
///
/// A single pass is required for correctness: sequential `replace` calls
/// misread an escaped backslash followed by `n` (`\\n`) as a newline. Any
/// other backslash sequence, and a trailing lone backslash, is kept verbatim.
fn unescape_basic_escapes(text: &str) -> String {
    let mut out = String::with_capacity(text.len());
    let mut chars = text.chars();
    while let Some(c) = chars.next() {
        if c != '\\' {
            out.push(c);
            continue;
        }
        match chars.next() {
            Some('n') => out.push('\n'),
            Some('t') => out.push('\t'),
            Some('"') => out.push('"'),
            Some('\\') => out.push('\\'),
            Some(other) => {
                out.push('\\');
                out.push(other);
            }
            None => out.push('\\'),
        }
    }
    out
}

/// Cleans a field value that is followed by another key: trims whitespace
/// and trailing commas, removes one opening and one closing quote, then
/// unescapes the content.
///
/// Only a single closing quote is removed so that content which itself ends
/// in a quote (e.g. `say \"hi\"`) is preserved.
fn unescape_field_value(raw: &str) -> String {
    let t = raw.trim().trim_end_matches(',').trim();
    let inner = t.strip_prefix('"').unwrap_or(t);
    let inner = inner.strip_suffix('"').unwrap_or(inner);
    unescape_basic_escapes(inner)
}

/// Index of the value's closing quote when it directly precedes the last
/// `"replace_all"` key, allowing any whitespace and an optional comma between.
fn closing_quote_before_replace_all(inner: &str) -> Option<usize> {
    let key_pos = inner.rfind("\"replace_all\"")?;
    let head = inner[..key_pos].trim_end();
    let head = head.strip_suffix(',').unwrap_or(head).trim_end();
    // `head` is a prefix of `inner`, so its stripped length is the quote's index.
    head.strip_suffix('"').map(str::len)
}

/// Index of the value's closing quote when only whitespace separates it from
/// a final `}`.
fn closing_quote_before_final_brace(inner: &str) -> Option<usize> {
    let body = inner.trim_end().strip_suffix('}')?.trim_end();
    body.strip_suffix('"').map(str::len)
}

/// Cleans the last field value of the object: removes the opening quote,
/// cuts the content at its closing quote, then unescapes it.
///
/// The closing quote is located by the first matching rule:
/// 1. the last `", "replace_all"`;
/// 2. a quote before the last `"replace_all"` with any spacing / optional comma;
/// 3. the last `"}`;
/// 4. the last `"\n}`;
/// 5. a quote separated from the final `}` only by whitespace.
///
/// If none matches, the whole remaining text is treated as content.
fn unescape_field_value_end(raw: &str) -> String {
    let t = raw.trim();
    let inner = t.strip_prefix('"').unwrap_or(t);
    let end = inner
        .rfind("\", \"replace_all\"")
        .or_else(|| closing_quote_before_replace_all(inner))
        .or_else(|| inner.rfind("\"}"))
        .or_else(|| inner.rfind("\"\n}"))
        .or_else(|| closing_quote_before_final_brace(inner))
        .unwrap_or(inner.len());
    unescape_basic_escapes(&inner[..end])
}
426
/// Unit tests for the repair chain. Every test pins a concrete parsed value
/// so a regression cannot pass by merely not panicking.
#[cfg(test)]
mod tests {
    use super::*;

    // --- repair_json tests ---

    #[test]
    fn repair_trailing_comma() {
        let input = r#"{"key": "value",}"#;
        let repaired = repair_json(input);
        let parsed: serde_json::Value =
            serde_json::from_str(&repaired).expect("should be valid JSON");
        assert_eq!(parsed["key"], "value");
    }

    #[test]
    fn repair_single_quotes() {
        let input = "{'key': 'value'}";
        let repaired = repair_json(input);
        let parsed: serde_json::Value =
            serde_json::from_str(&repaired).expect("should be valid JSON");
        assert_eq!(parsed["key"], "value");
    }

    #[test]
    fn repair_missing_closing_brace() {
        let input = r#"{"key": "value""#;
        let repaired = repair_json(input);
        let parsed: serde_json::Value =
            serde_json::from_str(&repaired).expect("should be valid JSON");
        assert_eq!(parsed["key"], "value");
    }

    #[test]
    fn repair_unquoted_keys() {
        let input = r#"{path: "src/main.rs"}"#;
        let repaired = repair_json(input);
        let parsed: serde_json::Value =
            serde_json::from_str(&repaired).expect("should be valid JSON");
        assert_eq!(parsed["path"], "src/main.rs");
    }

    #[test]
    fn repair_invalid_backslash_escape() {
        // \. is not a valid JSON escape — should be doubled to \\.
        let input = r#"{"pattern": "app\.rs"}"#;
        let repaired = repair_json(input);
        let parsed: serde_json::Value =
            serde_json::from_str(&repaired).expect("should be valid JSON after escape repair");
        // After repair \. becomes \\. which JSON parses as literal backslash + dot
        assert_eq!(parsed["pattern"], r"app\.rs");
    }

    #[test]
    fn repair_missing_comma_between_fields() {
        let input = r#"{"path": "src" "depth": 2}"#;
        let repaired = repair_json(input);
        let parsed: serde_json::Value =
            serde_json::from_str(&repaired).expect("should insert the missing comma");
        assert_eq!(parsed["path"], "src");
        assert_eq!(parsed["depth"], 2);
    }

    #[test]
    fn repair_markdown_fence_json() {
        let input = "```json\n{\"key\": \"value\"}\n```";
        let repaired = repair_json(input);
        let parsed: serde_json::Value =
            serde_json::from_str(&repaired).expect("should strip fences");
        assert_eq!(parsed["key"], "value");
    }

    #[test]
    fn repair_markdown_fence_no_lang() {
        let input = "```\n{\"key\": \"value\"}\n```";
        let repaired = repair_json(input);
        let parsed: serde_json::Value =
            serde_json::from_str(&repaired).expect("should strip fences");
        assert_eq!(parsed["key"], "value");
    }

    // --- extract_json_fields tests ---

    #[test]
    fn extract_fields_basic_key_value() {
        let input = r#"{"file_path": "/src/main.rs", "pattern": "hello"}"#;
        let result = extract_json_fields(input);
        assert_eq!(result["file_path"], "/src/main.rs");
        assert_eq!(result["pattern"], "hello");
    }

    #[test]
    fn extract_fields_boolean_values() {
        let input = r#"{"recursive": true, "case_sensitive": false}"#;
        let result = extract_json_fields(input);
        assert_eq!(result["recursive"], true);
        assert_eq!(result["case_sensitive"], false);
    }

    #[test]
    fn extract_fields_bare_keys() {
        let input = r#"{path: "/tmp/foo", depth: 3}"#;
        let result = extract_json_fields(input);
        assert_eq!(result["path"], "/tmp/foo");
        assert_eq!(result["depth"], 3);
    }

    // --- extract_edit_file_args tests ---

    #[test]
    fn extract_edit_file_standard_escaped_newlines() {
        let input = r#"{"file_path": "/src/lib.rs", "old_string": "fn old(){\n}", "new_string": "fn new(){\n}"}"#;
        let result = extract_edit_file_args(input).expect("should parse");
        assert_eq!(result["file_path"], "/src/lib.rs");
        // \n sequences in old_string/new_string get unescaped to real newlines
        assert_eq!(result["old_string"], "fn old(){\n}");
        assert_eq!(result["new_string"], "fn new(){\n}");
    }

    #[test]
    fn extract_edit_file_returns_none_on_missing_markers() {
        let input = r#"{"file_path": "/src/lib.rs"}"#;
        assert!(extract_edit_file_args(input).is_none());
    }

    #[test]
    fn extract_edit_file_replace_all_true() {
        let input = r#"{"file_path": "/src/lib.rs", "old_string": "foo", "new_string": "bar", "replace_all": true}"#;
        let result = extract_edit_file_args(input).expect("should parse");
        assert_eq!(result["replace_all"], true);
    }

    // --- repair_tool_args tests ---

    #[test]
    fn repair_tool_args_passes_valid_json_through() {
        let input = r#"{"file_path":"/tmp/a.rs","content":"x"}"#;
        assert_eq!(repair_tool_args("write_file", input), input);
    }

    #[test]
    fn repair_tool_args_fixes_fence_wrapped_json() {
        let input = "```json\n{\"file_path\":\"/tmp/a.rs\",\"content\":\"x\"}\n```";
        let out = repair_tool_args("write_file", input);
        let v: serde_json::Value = serde_json::from_str(&out).expect("should parse");
        assert_eq!(v["file_path"], "/tmp/a.rs");
    }

    #[test]
    fn repair_tool_args_keeps_empty_object_untouched() {
        // Empty `{}` is valid JSON — we must not paper over it by inventing fields.
        // Callers surface it as a user-visible error instead.
        assert_eq!(repair_tool_args("write_file", "{}"), "{}");
    }

    #[test]
    fn repair_tool_args_returns_original_when_unsalvageable() {
        // Pure garbage with no extractable key=value pairs → return as-is so
        // the tool emits the real parse error (not a misleading repaired stub).
        let input = "!!!";
        assert_eq!(repair_tool_args("write_file", input), "!!!");
    }
}