use serde_json::Value;
/// Result of [`try_extract_usable_args`]: the recovered arguments together with
/// a record of any text that had to be stripped to obtain them.
#[derive(Debug)]
pub struct ToolCallRepair {
    /// Parsed arguments; always a JSON object when produced by
    /// `try_extract_usable_args`.
    pub args: Value,
    /// Whether the input parsed unchanged or needed repair.
    pub kind: RepairKind,
    /// Trimmed text found before the extracted JSON (empty when none).
    pub leading_prefix: String,
    /// Trimmed text found after the extracted JSON (may be empty).
    pub trailing_suffix: String,
}
/// How the arguments in a [`ToolCallRepair`] were obtained.
///
/// Fieldless, so it is `Copy`: callers can read `repair.kind` out of a
/// borrowed repair without moving it.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum RepairKind {
    /// The raw input parsed as a JSON object without modification.
    Preserved,
    /// The object was recovered only after stripping surrounding text and/or
    /// rewriting the JSON text.
    Repaired,
}
/// Finds the first bracket-balanced `{…}` or `[…]` span in `raw`.
///
/// Scanning starts at the first `{` or `[`. Brackets inside JSON string
/// literals (including `\"`-escaped quotes) are ignored. Opening and closing
/// brackets of either kind are counted together, so a mismatched pair such as
/// `{]` is treated as balanced.
///
/// Returns the span as an owned `String` together with its starting **byte**
/// offset in `raw`. Returns `None` if there is no opening bracket or the
/// brackets never balance (e.g. truncated input).
pub fn extract_balanced_json_prefix(raw: &str) -> Option<(String, usize)> {
    // `find` yields a char-boundary byte offset, so slicing at it is safe even
    // when multi-byte UTF-8 text precedes the JSON.
    let start = raw.find(|c: char| c == '{' || c == '[')?;
    let candidate = &raw[start..];

    // `candidate` begins with an opener, so `depth` is >= 1 before any closer
    // is processed and never underflows.
    let mut depth: usize = 0;
    let mut in_string = false;
    let mut escaped = false;
    for (offset, c) in candidate.char_indices() {
        if in_string {
            if escaped {
                escaped = false;
            } else if c == '\\' {
                escaped = true;
            } else if c == '"' {
                in_string = false;
            }
            continue;
        }
        match c {
            '"' => in_string = true,
            '{' | '[' => depth += 1,
            '}' | ']' => {
                depth -= 1;
                if depth == 0 {
                    let end = offset + c.len_utf8();
                    return Some((candidate[..end].to_owned(), start));
                }
            }
            _ => {}
        }
    }
    None
}
/// Decides whether it is worth attempting a repair of `partial_json` after
/// `delta` has been received.
///
/// Returns `true` when `delta` contains a closing `}` or `]`. Otherwise it
/// returns `true` only when `delta` is at most three bytes long once trimmed
/// and `partial_json` already contains a closing bracket.
// NOTE(review): the allow suppresses the unused-item lint; no caller is
// visible in this file — confirm whether it is still needed.
#[allow(dead_code)]
pub fn should_attempt_repair(partial_json: &str, delta: &str) -> bool {
    let has_closer = |text: &str| text.contains(|ch: char| matches!(ch, '}' | ']'));
    if has_closer(delta) {
        return true;
    }
    let short_delta = delta.trim().len() <= 3;
    short_delta && has_closer(partial_json)
}
/// Decides whether `prefix` is tolerable junk before an extracted JSON span.
///
/// Rules, in order:
/// * an empty prefix is always allowed;
/// * anything longer than 96 bytes is rejected;
/// * every char must be an ASCII lowercase letter, an ASCII digit, Unicode
///   whitespace, or one of `" ' ` . : / _ \ -` (so uppercase text is rejected);
/// * finally, the prefix must be at most 10 bytes long, start with one of
///   `.` `:` `"` `` ` ``, or start with `function` or `tool`.
///
/// This is a plain char predicate rather than a per-call `Regex` compile. It
/// is equivalent to the former `^[a-z0-9\s"'`.:/_\\-]+$` pattern, whose `\s`
/// is the Unicode White_Space class, matching `char::is_whitespace`.
fn is_allowed_leading_prefix(prefix: &str) -> bool {
    if prefix.is_empty() {
        return true;
    }
    if prefix.len() > 96 {
        return false;
    }

    let is_allowed_char = |c: char| {
        c.is_ascii_lowercase()
            || c.is_ascii_digit()
            || c.is_whitespace()
            || matches!(c, '"' | '\'' | '`' | '.' | ':' | '/' | '_' | '\\' | '-')
    };
    if !prefix.chars().all(is_allowed_char) {
        return false;
    }

    // The charset above already excludes uppercase letters, and "function"
    // / "tool" subsume "functions" / "tools", so no case folding is needed.
    prefix.len() <= 10
        || prefix.starts_with(|c: char| matches!(c, '.' | ':' | '"' | '`'))
        || prefix.starts_with("function")
        || prefix.starts_with("tool")
}
/// Decides whether `suffix` is tolerable junk after an extracted JSON span.
///
/// An empty suffix is allowed. A non-empty one must be at most 3 bytes long
/// and contain no whitespace, no brackets (`{ [ } ]`), no `"` and no `\`.
fn is_allowed_trailing_suffix(suffix: &str) -> bool {
    const FORBIDDEN: &[char] = &['{', '[', '}', ']', '"', '\\'];
    match suffix.len() {
        0 => true,
        1..=3 => suffix
            .chars()
            .all(|ch| !ch.is_whitespace() && !FORBIDDEN.contains(&ch)),
        _ => false,
    }
}
pub fn try_extract_usable_args(raw: &str) -> Option<ToolCallRepair> {
if raw.trim().is_empty() {
return None;
}
if let Ok(parsed) = serde_json::from_str::<Value>(raw) {
if parsed.is_object() {
return Some(ToolCallRepair {
args: parsed,
kind: RepairKind::Preserved,
leading_prefix: String::new(),
trailing_suffix: String::new(),
});
}
}
let extracted = extract_balanced_json_prefix(raw)?;
let leading_prefix = raw[..extracted.1].trim().to_string();
let json_part = &extracted.0;
let trailing_suffix = raw[extracted.1 + json_part.len()..].trim().to_string();
if !leading_prefix.is_empty() && !is_allowed_leading_prefix(&leading_prefix) {
return None;
}
if !leading_prefix.is_empty() && !trailing_suffix.is_empty() {
if !is_allowed_trailing_suffix(&trailing_suffix) {
return None;
}
}
if let Ok(parsed) = serde_json::from_str::<Value>(json_part) {
if parsed.is_object() {
return Some(ToolCallRepair {
args: parsed,
kind: RepairKind::Repaired,
leading_prefix,
trailing_suffix,
});
}
}
if !json_part.is_empty() {
let fixed = json_part
.replace('\n', "\\n")
.replace('\r', "\\r")
.replace('\t', "\\t");
if let Ok(parsed) = serde_json::from_str::<Value>(&fixed) {
if parsed.is_object() {
return Some(ToolCallRepair {
args: parsed,
kind: RepairKind::Repaired,
leading_prefix,
trailing_suffix,
});
}
}
}
if json_part.starts_with('{') && json_part.contains(':') {
let incomplete = json_part.to_string();
let fixed_incomplete = fix_incomplete_json(&incomplete);
if let Ok(parsed) = serde_json::from_str::<Value>(&fixed_incomplete) {
if parsed.is_object() && !parsed.as_object().unwrap().is_empty() {
return Some(ToolCallRepair {
args: parsed,
kind: RepairKind::Repaired,
leading_prefix,
trailing_suffix: String::new(), });
}
}
}
None
}
/// Best-effort completion of a truncated JSON document.
///
/// Scans the trimmed input with a string-literal-aware lexer, then:
/// * closes an unterminated string literal, dropping a dangling `\` first;
/// * drops a trailing `,` left before the cut;
/// * appends the matching `]` or `}` for every still-open bracket, innermost
///   first.
///
/// Already-complete input is returned trimmed but otherwise unchanged, and
/// empty or whitespace-only input yields `"{}"`. The result is not guaranteed
/// to be valid JSON (for example, a dangling `"key":` is left as-is), so
/// callers must still parse it.
fn fix_incomplete_json(incomplete: &str) -> String {
    let trimmed = incomplete.trim();
    if trimmed.is_empty() {
        return "{}".to_string();
    }

    // Closers still owed, innermost last.
    let mut open: Vec<char> = Vec::new();
    let mut in_string = false;
    let mut escaped = false;
    for c in trimmed.chars() {
        if in_string {
            if escaped {
                escaped = false;
            } else if c == '\\' {
                escaped = true;
            } else if c == '"' {
                in_string = false;
            }
            continue;
        }
        match c {
            '"' => in_string = true,
            '{' => open.push('}'),
            '[' => open.push(']'),
            '}' | ']' => {
                // Stray closers on an empty stack are ignored.
                open.pop();
            }
            _ => {}
        }
    }

    let mut result = trimmed.to_string();
    if in_string {
        if escaped {
            // The cut fell right after a `\`; closing with `"` would only
            // escape the quote, so drop the lone backslash.
            result.pop();
        }
        result.push('"');
    } else if !open.is_empty() && result.ends_with(',') {
        result.pop();
    }
    result.extend(open.iter().rev());
    result
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_extract_balanced_json_simple() {
        let (json, start) = extract_balanced_json_prefix(r#"hello {"key": "value"} world"#)
            .expect("expected a balanced JSON span");
        assert_eq!(json, r#"{"key": "value"}"#);
        assert_eq!(start, 6);
    }

    #[test]
    fn test_extract_balanced_json_nested() {
        let (json, _) =
            extract_balanced_json_prefix(r#"garbage {"a": [1, 2, {"b": true}]} tail"#)
                .expect("expected a balanced JSON span");
        let parsed: Value =
            serde_json::from_str(&json).expect("extracted span should be valid JSON");
        let items = parsed["a"].as_array().expect("`a` should be an array");
        assert_eq!(items.len(), 3);
    }

    #[test]
    fn test_try_extract_usable_args_valid() {
        let repair = try_extract_usable_args(r#"{"content": "hello"}"#)
            .expect("valid JSON should be usable");
        assert_eq!(repair.kind, RepairKind::Preserved);
        assert_eq!(repair.args["content"], "hello");
    }

    #[test]
    fn test_try_extract_usable_args_with_garbage() {
        let repair = try_extract_usable_args(r#"abc {"content": "hello"} ab"#)
            .expect("short surrounding junk should be tolerated");
        assert_eq!(repair.kind, RepairKind::Repaired);
        assert_eq!(repair.args["content"], "hello");
        assert_eq!(repair.leading_prefix, "abc");
        assert_eq!(repair.trailing_suffix, "ab");
    }

    #[test]
    fn test_should_attempt_repair() {
        let cases = [
            (r#"{"key"#, "}", true),
            (r#"{""#, "a", false),
            (r#"{""#, "x}", true),
        ];
        for (partial, delta, expected) in cases {
            assert_eq!(
                should_attempt_repair(partial, delta),
                expected,
                "partial={:?} delta={:?}",
                partial,
                delta
            );
        }
    }

    #[test]
    fn test_try_extract_usable_args_with_unescaped_newlines() {
        let raw = "{\"content\": \"line1\nline2\", \"path\": \"test.rs\"}";
        let repair = try_extract_usable_args(raw)
            .expect("a raw newline inside a string should be repairable");
        assert_eq!(repair.kind, RepairKind::Repaired);
    }
}