use crate::provider::{ContentPart, Message, MessageContent, Role};
use serde_json::Value;
/// Doubles stray backslashes inside JSON string literals so that text such as
/// an unescaped Windows path (`"C:\dir"`) becomes parseable.
///
/// Only characters between double quotes are rewritten. Inside a string,
/// `\"`, `\\` and `\uXXXX` (exactly four hex digits) are copied through
/// unchanged; every other backslash is emitted as `\\`. That includes the ones
/// that would otherwise form `\n`, `\t` or `\/`, so those decode as a literal
/// backslash. Text outside string literals is copied verbatim.
pub fn fix_json_backslashes(raw: &str) -> String {
    let chars: Vec<char> = raw.chars().collect();
    let mut fixed = String::with_capacity(raw.len() + 32);
    let mut inside_string = false;
    let mut pos = 0;

    while let Some(&ch) = chars.get(pos) {
        // Anything other than a backslash inside a string is copied as-is;
        // a double quote toggles between "outside" and "inside" a literal.
        if !inside_string || ch != '\\' {
            fixed.push(ch);
            if ch == '"' {
                inside_string = !inside_string;
            }
            pos += 1;
            continue;
        }

        match chars.get(pos + 1).copied() {
            // Escaped quote or escaped backslash: keep the pair untouched.
            Some(next) if next == '"' || next == '\\' => {
                fixed.push('\\');
                fixed.push(next);
                pos += 2;
            }
            // Well-formed `\uXXXX` escape: keep all six characters.
            Some('u')
                if chars
                    .get(pos + 2..pos + 6)
                    .map_or(false, |hex| hex.iter().all(char::is_ascii_hexdigit)) =>
            {
                fixed.extend(&chars[pos..pos + 6]);
                pos += 6;
            }
            // Any other (or trailing) backslash becomes a literal backslash.
            _ => {
                fixed.push_str("\\\\");
                pos += 1;
            }
        }
    }
    fixed
}
/// Tool-call arguments recovered by `try_extract_usable_args`, together with
/// the text that was stripped from around them.
#[derive(Debug, PartialEq)]
pub struct ToolCallRepair {
    /// The recovered arguments; `try_extract_usable_args` only ever stores a
    /// JSON object here.
    pub args: Value,
    /// Whether `args` was taken verbatim from the input or had to be repaired.
    pub kind: RepairKind,
    /// Trimmed text that preceded the extracted JSON (empty when there was none).
    pub leading_prefix: String,
    /// Trimmed text that followed the extracted JSON. It can be empty even when
    /// text followed the JSON, because not every repair path reports it.
    pub trailing_suffix: String,
}
/// How the arguments in a `ToolCallRepair` were obtained.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum RepairKind {
    /// The raw text already parsed as a JSON object and was used verbatim.
    Preserved,
    /// The arguments were recovered by extracting and/or fixing up the raw text.
    Repaired,
}
/// Finds the first `{` or `[` in `raw` and returns the shortest bracket-balanced
/// slice starting there, together with its byte offset into `raw`.
///
/// Brackets inside JSON string literals (honouring `\` escapes) are ignored.
/// `{`/`[` and `}`/`]` are counted together, so mismatched pairs such as `{]`
/// are not detected.
///
/// Returns `None` when `raw` has no opening bracket or the brackets never balance.
pub fn extract_balanced_json_prefix(raw: &str) -> Option<(String, usize)> {
    // `find` returns a byte offset that lies on a char boundary. Stepping one
    // byte at a time would panic when multi-byte text (e.g. CJK prose) precedes
    // the JSON.
    let start = raw.find(|c: char| c == '{' || c == '[')?;
    let body = &raw[start..];

    let mut depth: usize = 0;
    let mut in_string = false;
    let mut escaped = false;
    for (offset, c) in body.char_indices() {
        if in_string {
            if escaped {
                escaped = false;
            } else if c == '\\' {
                escaped = true;
            } else if c == '"' {
                in_string = false;
            }
            continue;
        }
        match c {
            '"' => in_string = true,
            '{' | '[' => depth += 1,
            '}' | ']' => {
                // `body` starts with an opener and we return as soon as depth
                // reaches zero, so depth is always >= 1 here.
                depth -= 1;
                if depth == 0 {
                    let end = offset + c.len_utf8();
                    return Some((body[..end].to_owned(), start));
                }
            }
            _ => {}
        }
    }
    None
}
/// Cheap pre-check for whether a repair attempt is worthwhile after `delta`
/// has been appended to `partial_json`. Presumably called as argument text
/// arrives incrementally; that is not confirmed by this file.
///
/// Returns `true` when `delta` contains a closing `}`/`]`. It also returns
/// `true` when `delta` is at most three bytes after trimming and `partial_json`
/// already contains a closing bracket.
// NOTE(review): no non-test caller is visible in this file, hence the allow.
#[allow(dead_code)]
pub fn should_attempt_repair(partial_json: &str, delta: &str) -> bool {
    let is_closer = |c: char| c == '}' || c == ']';
    if delta.contains(is_closer) {
        return true;
    }
    let short_delta = delta.trim().len() <= 3;
    short_delta && partial_json.contains(is_closer)
}
/// Decides whether `prefix` (the trimmed text found before an extracted JSON
/// value) is harmless enough to discard.
///
/// The empty prefix is always allowed. Any other prefix must satisfy all of:
/// * it is at most 96 bytes long;
/// * it consists only of lowercase ASCII letters, ASCII digits, whitespace and
///   the punctuation `"` `'` `` ` `` `.` `:` `/` `_` `\` `-`;
/// * it is short (at most 10 bytes), starts with one of `.` `:` `"` `` ` ``,
///   or starts with `function` / `tool` (which also covers `functions` / `tools`).
fn is_allowed_leading_prefix(prefix: &str) -> bool {
    if prefix.is_empty() {
        return true;
    }
    if prefix.len() > 96 {
        return false;
    }
    // Same set as the character class `^[a-z0-9\s"'`.:/_\\-]+$`. It is checked
    // directly rather than by compiling a regex on every call; `\s` in the regex
    // crate is Unicode White_Space, exactly what `char::is_whitespace` tests.
    let is_allowed_char = |c: char| {
        c.is_ascii_lowercase()
            || c.is_ascii_digit()
            || c.is_whitespace()
            || "\"'`.:/_\\-".contains(c)
    };
    if !prefix.chars().all(is_allowed_char) {
        return false;
    }
    // Uppercase is already rejected above, so no case folding is needed, and
    // "function"/"tool" subsume "functions"/"tools".
    prefix.len() <= 10
        || matches!(prefix.chars().next(), Some('.' | ':' | '"' | '`'))
        || prefix.starts_with("function")
        || prefix.starts_with("tool")
}
/// Decides whether `suffix` (the trimmed text found after an extracted JSON
/// value) may be discarded.
///
/// The empty suffix is allowed. Otherwise it must be at most 3 bytes long and
/// contain no whitespace, brackets (`{ [ } ]`), double quotes or backslashes.
fn is_allowed_trailing_suffix(suffix: &str) -> bool {
    const FORBIDDEN: [char; 6] = ['{', '[', '}', ']', '"', '\\'];
    match suffix.len() {
        0 => true,
        len if len > 3 => false,
        _ => suffix
            .chars()
            .all(|ch| !ch.is_whitespace() && !FORBIDDEN.contains(&ch)),
    }
}
pub fn try_extract_usable_args(raw: &str) -> Option<ToolCallRepair> {
if raw.trim().is_empty() {
return None;
}
if let Ok(parsed) = serde_json::from_str::<Value>(raw) {
if parsed.is_object() {
return Some(ToolCallRepair {
args: parsed,
kind: RepairKind::Preserved,
leading_prefix: String::new(),
trailing_suffix: String::new(),
});
}
}
let extracted = extract_balanced_json_prefix(raw)?;
let leading_prefix = raw.get(..extracted.1)
.unwrap_or("")
.trim()
.to_string();
let json_part = &extracted.0;
let suffix_start = extracted.1 + json_part.len();
let trailing_suffix = raw.get(suffix_start..)
.unwrap_or("")
.trim()
.to_string();
if !leading_prefix.is_empty() && !is_allowed_leading_prefix(&leading_prefix) {
return None;
}
if !leading_prefix.is_empty() && !trailing_suffix.is_empty() {
if !is_allowed_trailing_suffix(&trailing_suffix) {
return None;
}
}
if let Ok(parsed) = serde_json::from_str::<Value>(json_part) {
if parsed.is_object() {
return Some(ToolCallRepair {
args: parsed,
kind: RepairKind::Repaired,
leading_prefix,
trailing_suffix,
});
}
}
if !json_part.is_empty() {
let fixed = fix_json_backslashes(json_part)
.replace('\n', "\\n")
.replace('\r', "\\r")
.replace('\t', "\\t");
if let Ok(parsed) = serde_json::from_str::<Value>(&fixed) {
if parsed.is_object() {
return Some(ToolCallRepair {
args: parsed,
kind: RepairKind::Repaired,
leading_prefix,
trailing_suffix,
});
}
}
}
if json_part.starts_with('{') && json_part.contains(':') {
let incomplete = json_part.to_string();
let fixed_incomplete = fix_incomplete_json(&incomplete);
if let Ok(parsed) = serde_json::from_str::<Value>(&fixed_incomplete) {
if parsed.is_object() && !parsed.as_object().unwrap().is_empty() {
return Some(ToolCallRepair {
args: parsed,
kind: RepairKind::Repaired,
leading_prefix,
trailing_suffix: String::new(), });
}
}
}
None
}
/// Best-effort completion of a JSON object that fails to parse.
///
/// 1. Blank input becomes `{}`; input that already parses is returned trimmed.
/// 2. Missing values (`"key":}` / `"key":,`, optionally with one space after
///    the colon) are replaced by `null`, and `,}` is collapsed to `}`. These are
///    plain text substitutions, so they can also touch string contents.
/// 3. If that still does not parse, an unterminated string literal is closed.
///    A trailing `,` is dropped and a dangling `:` gets `null`. Then every
///    still-open `{` / `[` is closed with its matching bracket, innermost first.
///
/// The result is not guaranteed to be valid JSON; callers must parse it again.
fn fix_incomplete_json(incomplete: &str) -> String {
    let trimmed = incomplete.trim();
    if trimmed.is_empty() {
        return "{}".to_string();
    }
    if serde_json::from_str::<Value>(trimmed).is_ok() {
        return trimmed.to_string();
    }
    let fixed = trimmed
        .replace("\":}", "\":null}")
        .replace("\":,", "\":null,")
        .replace("\": }", "\": null}")
        .replace("\": ,", "\": null,")
        .replace(",}", "}");
    if serde_json::from_str::<Value>(&fixed).is_ok() {
        return fixed;
    }
    // Build on `fixed` (not `trimmed`) so the null substitutions are kept.
    close_open_json(&fixed)
}

/// Appends whatever is needed to terminate an open string literal and every
/// unclosed `{` / `[` in `text`. Brackets inside string literals are ignored,
/// and each opener is closed with its own matching bracket.
fn close_open_json(text: &str) -> String {
    let mut pending_closers: Vec<char> = Vec::new();
    let mut in_string = false;
    let mut escaped = false;
    for c in text.chars() {
        if in_string {
            if escaped {
                escaped = false;
            } else if c == '\\' {
                escaped = true;
            } else if c == '"' {
                in_string = false;
            }
            continue;
        }
        match c {
            '"' => in_string = true,
            '{' => pending_closers.push('}'),
            '[' => pending_closers.push(']'),
            '}' | ']' => {
                pending_closers.pop();
            }
            _ => {}
        }
    }

    let mut result = text.to_string();
    if in_string {
        if escaped {
            // A lone trailing backslash would escape the closing quote; drop it.
            result.pop();
        }
        result.push('"');
    } else {
        // A container cannot be closed right after a separator.
        let kept = result.trim_end().len();
        result.truncate(kept);
        if result.ends_with(',') {
            result.pop();
        } else if result.ends_with(':') {
            result.push_str("null");
        }
    }
    while let Some(closer) = pending_closers.pop() {
        result.push(closer);
    }
    result
}
/// Output of `repair_tool_result_pairing`.
#[derive(Debug, Clone)]
pub struct RepairResult {
    /// The conversation after repair. Orphaned tool-result messages are removed,
    /// and synthetic error results are inserted after assistant messages whose
    /// tool calls had no result.
    pub messages: Vec<Message>,
    /// Copies of just the synthetic tool-result messages that were inserted,
    /// in insertion order.
    pub synthetic_messages: Vec<Message>,
}
/// Repairs a conversation so that assistant tool calls and tool results line up.
///
/// Two fixes are applied, keeping every other message in its original order:
/// * A `Role::Tool` message with `Parts` content is dropped (with a warning)
///   when *every* part is a `ToolResult` whose id no assistant message issued.
///   An empty part list also counts. Tool messages holding any other part, or
///   non-`Parts` content, are kept.
/// * For every assistant `ToolUse` id that has no `ToolResult` anywhere in the
///   conversation, a synthetic error result is inserted immediately after that
///   assistant message. It is therefore placed ahead of any real results that
///   follow it.
///
/// The inserted synthetic messages are also returned separately in
/// `RepairResult::synthetic_messages`.
pub fn repair_tool_result_pairing(messages: Vec<Message>) -> RepairResult {
    // Hash sets keep the per-message membership checks O(1) instead of
    // rescanning a Vec for every id.
    let issued_ids: std::collections::HashSet<String> =
        messages.iter().flat_map(assistant_tool_use_ids).collect();
    let mut answered_ids: std::collections::HashSet<String> = std::collections::HashSet::new();
    for msg in messages.iter().filter(|m| m.role == Role::Tool) {
        if let MessageContent::Parts(parts) = &msg.content {
            for part in parts {
                if let ContentPart::ToolResult { tool_use_id, .. } = part {
                    answered_ids.insert(tool_use_id.clone());
                }
            }
        }
    }

    let mut repaired: Vec<Message> = Vec::with_capacity(messages.len());
    let mut synthetic_messages: Vec<Message> = Vec::new();
    for msg in messages {
        if msg.role == Role::Tool {
            let is_orphan = matches!(
                &msg.content,
                MessageContent::Parts(parts) if parts.iter().all(|part| matches!(
                    part,
                    ContentPart::ToolResult { tool_use_id, .. }
                        if !issued_ids.contains(tool_use_id)
                ))
            );
            if is_orphan {
                tracing::warn!(
                    "repair_tool_result_pairing: removing orphaned tool result for unknown tool_call_id"
                );
                continue;
            }
        }

        // Tool calls issued by this message that never received a result.
        let unanswered: Vec<String> = assistant_tool_use_ids(&msg)
            .into_iter()
            .filter(|id| !answered_ids.contains(id))
            .collect();
        repaired.push(msg);
        for missing_id in unanswered {
            tracing::warn!(
                tool_call_id = %missing_id,
                "repair_tool_result_pairing: adding synthetic error result for missing tool_call_id"
            );
            let synthetic = Message {
                role: Role::Tool,
                content: MessageContent::Parts(vec![ContentPart::ToolResult {
                    tool_use_id: missing_id,
                    content: "[Session interrupted: tool execution was not completed]".to_owned(),
                    is_error: Some(true),
                }]),
            };
            repaired.push(synthetic.clone());
            synthetic_messages.push(synthetic);
        }
    }

    RepairResult {
        messages: repaired,
        synthetic_messages,
    }
}

/// Returns the ids of the `ToolUse` parts of an assistant message, in order.
/// Any other role, or non-`Parts` content, yields an empty list.
fn assistant_tool_use_ids(msg: &Message) -> Vec<String> {
    if msg.role != Role::Assistant {
        return Vec::new();
    }
    match &msg.content {
        MessageContent::Parts(parts) => parts
            .iter()
            .filter_map(|part| match part {
                ContentPart::ToolUse { id, .. } => Some(id.clone()),
                _ => None,
            })
            .collect(),
        _ => Vec::new(),
    }
}
// Unit tests for the JSON argument-repair helpers and for tool-call/tool-result
// pairing repair.
#[cfg(test)]
mod tests {
use super::*;
// Surrounding prose is skipped; the returned offset is the byte index of `{`.
#[test]
fn test_extract_balanced_json_simple() {
let result = extract_balanced_json_prefix(r#"hello {"key": "value"} world"#);
assert!(result.is_some());
let (json, start) = result.unwrap();
assert_eq!(json, r#"{"key": "value"}"#);
assert_eq!(start, 6);
}
// Nested arrays/objects are kept intact up to the matching outer closing brace.
#[test]
fn test_extract_balanced_json_nested() {
let result = extract_balanced_json_prefix(r#"garbage {"a": [1, 2, {"b": true}]} tail"#);
assert!(result.is_some());
let (json, _) = result.unwrap();
let parsed: Value = serde_json::from_str(&json).unwrap();
assert_eq!(parsed["a"].as_array().unwrap().len(), 3);
}
// Input that already parses as a JSON object is reported as Preserved.
#[test]
fn test_try_extract_usable_args_valid() {
let result = try_extract_usable_args(r#"{"content": "hello"}"#);
assert!(result.is_some());
let repair = result.unwrap();
assert_eq!(repair.kind, RepairKind::Preserved);
assert_eq!(repair.args["content"], "hello");
}
// Short leading and trailing junk is stripped and reported back (trimmed).
#[test]
fn test_try_extract_usable_args_with_garbage() {
let result = try_extract_usable_args(r#"abc {"content": "hello"} ab"#);
assert!(result.is_some());
let repair = result.unwrap();
assert_eq!(repair.kind, RepairKind::Repaired);
assert_eq!(repair.args["content"], "hello");
assert_eq!(repair.leading_prefix, "abc");
assert_eq!(repair.trailing_suffix, "ab");
}
// A closing bracket in the delta triggers a repair attempt; a short plain
// delta does not when the buffer has no closing bracket yet.
#[test]
fn test_should_attempt_repair() {
assert!(should_attempt_repair(r#"{"key"#, "}"));
assert!(!should_attempt_repair(r#"{""#, "a"));
assert!(should_attempt_repair(r#"{""#, "x}"));
}
// A raw (unescaped) newline inside a string value is repaired rather than rejected.
#[test]
fn test_try_extract_usable_args_with_unescaped_newlines() {
let raw = "{\"content\": \"line1\nline2\", \"path\": \"test.rs\"}";
let result = try_extract_usable_args(raw);
assert!(result.is_some());
let repair = result.unwrap();
assert_eq!(repair.kind, RepairKind::Repaired);
}
// A key with no value before `}` (`"top_k":}`) is filled with null; the
// non-ASCII string value survives unchanged.
#[test]
fn test_try_extract_missing_value_before_brace() {
let raw = r#"{"action":"search","query":"所有记忆","top_k":}"#;
let result = try_extract_usable_args(raw);
assert!(result.is_some());
let repair = result.unwrap();
assert_eq!(repair.args["action"], "search");
assert_eq!(repair.args["query"], "所有记忆");
assert!(repair.args["top_k"].is_null());
}
// An assistant tool call with no result gets a synthetic error result
// inserted directly after the assistant message.
#[test]
fn test_repair_missing_tool_result() {
let messages = vec![
Message {
role: Role::User,
content: MessageContent::Text("hello".to_owned()),
},
Message {
role: Role::Assistant,
content: MessageContent::Parts(vec![
ContentPart::ToolUse {
id: "call_123".to_owned(),
name: "test_tool".to_owned(),
input: serde_json::json!({"arg": "value"}),
},
]),
},
];
let result = repair_tool_result_pairing(messages);
let repaired = &result.messages;
assert_eq!(repaired.len(), 3);
assert_eq!(repaired[0].role, Role::User);
assert_eq!(repaired[1].role, Role::Assistant);
assert_eq!(repaired[2].role, Role::Tool);
match &repaired[2].content {
MessageContent::Parts(parts) => {
assert_eq!(parts.len(), 1);
match &parts[0] {
ContentPart::ToolResult { tool_use_id, content, is_error } => {
assert_eq!(tool_use_id, "call_123");
assert!(content.contains("interrupted"));
assert_eq!(*is_error, Some(true));
}
_ => panic!("Expected ToolResult"),
}
}
_ => panic!("Expected Parts"),
}
assert_eq!(result.synthetic_messages.len(), 1);
}
// A fully paired conversation is returned unchanged, with no synthetic messages.
#[test]
fn test_repair_complete_pairing() {
let messages = vec![
Message {
role: Role::User,
content: MessageContent::Text("hello".to_owned()),
},
Message {
role: Role::Assistant,
content: MessageContent::Parts(vec![
ContentPart::ToolUse {
id: "call_123".to_owned(),
name: "test_tool".to_owned(),
input: serde_json::json!({}),
},
]),
},
Message {
role: Role::Tool,
content: MessageContent::Parts(vec![
ContentPart::ToolResult {
tool_use_id: "call_123".to_owned(),
content: "result".to_owned(),
is_error: Some(false),
},
]),
},
];
let result = repair_tool_result_pairing(messages);
assert_eq!(result.messages.len(), 3);
assert_eq!(result.messages[2].role, Role::Tool);
assert!(result.synthetic_messages.is_empty());
}
// With one of two calls answered, the synthetic result for the unanswered
// call is placed right after the assistant message, ahead of the real result.
#[test]
fn test_repair_multiple_tool_calls() {
let messages = vec![
Message {
role: Role::Assistant,
content: MessageContent::Parts(vec![
ContentPart::ToolUse {
id: "call_1".to_owned(),
name: "tool_a".to_owned(),
input: serde_json::json!({}),
},
ContentPart::ToolUse {
id: "call_2".to_owned(),
name: "tool_b".to_owned(),
input: serde_json::json!({}),
},
]),
},
Message {
role: Role::Tool,
content: MessageContent::Parts(vec![
ContentPart::ToolResult {
tool_use_id: "call_1".to_owned(),
content: "result 1".to_owned(),
is_error: Some(false),
},
]),
},
];
let result = repair_tool_result_pairing(messages);
let repaired = &result.messages;
assert_eq!(repaired.len(), 3);
assert_eq!(repaired[0].role, Role::Assistant);
assert_eq!(repaired[1].role, Role::Tool);
assert_eq!(repaired[2].role, Role::Tool);
match &repaired[1].content {
MessageContent::Parts(parts) => match &parts[0] {
ContentPart::ToolResult { tool_use_id, content, is_error } => {
assert_eq!(tool_use_id, "call_2");
assert!(content.contains("interrupted"));
assert_eq!(*is_error, Some(true));
}
_ => panic!("Expected ToolResult"),
},
_ => panic!("Expected Parts"),
}
match &repaired[2].content {
MessageContent::Parts(parts) => match &parts[0] {
ContentPart::ToolResult { tool_use_id, content, is_error } => {
assert_eq!(tool_use_id, "call_1");
assert_eq!(content, "result 1");
assert_eq!(*is_error, Some(false));
}
_ => panic!("Expected ToolResult"),
},
_ => panic!("Expected Parts"),
}
assert_eq!(result.synthetic_messages.len(), 1);
}
// A tool result whose id was never issued by any assistant message is removed.
#[test]
fn test_remove_orphaned_tool_result() {
let messages = vec![
Message {
role: Role::Assistant,
content: MessageContent::Text("just text".to_owned()),
},
Message {
role: Role::Tool,
content: MessageContent::Parts(vec![
ContentPart::ToolResult {
tool_use_id: "orphan_123".to_owned(),
content: "orphaned result".to_owned(),
is_error: Some(false),
},
]),
},
];
let result = repair_tool_result_pairing(messages);
assert_eq!(result.messages.len(), 1);
assert_eq!(result.messages[0].role, Role::Assistant);
assert!(result.synthetic_messages.is_empty());
}
}