use std::sync::Arc;
use crate::capabilities::{Capability, CapabilityLocalization};
use serde_json::Value;
/// Identifier returned by `ToolCallRepairCapability` from `Capability::id`.
pub const TOOL_CALL_REPAIR_CAPABILITY_ID: &str = "tool_call_repair";
/// Number of corrective re-prompts used when the config does not set `max_reprompts`.
pub const DEFAULT_MAX_REPROMPTS: u32 = 1;
/// Largest raw string argument, in bytes (256 KiB), that `salvage_tool_arguments`
/// will try to parse; longer strings are reported as unsalvageable untouched.
pub const MAX_SALVAGE_INPUT_BYTES: usize = 256 * 1024;
/// The way a malformed tool call ended up being handled.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum RepairOutcome {
    /// The arguments were fixed locally.
    // NOTE(review): presumably recorded when `salvage_tool_arguments` succeeds;
    // the producer of this variant is not visible in this file.
    LocalSalvage,
    /// Local salvage failed and another corrective re-prompt is still allowed.
    Reprompt,
    /// Local salvage failed and the re-prompt budget is exhausted.
    GaveUp,
}

impl RepairOutcome {
    /// Returns the fixed kebab-case label for this outcome.
    pub fn label(self) -> &'static str {
        use RepairOutcome::{GaveUp, LocalSalvage, Reprompt};

        match self {
            LocalSalvage => "local-salvage",
            Reprompt => "re-prompt",
            GaveUp => "gave-up",
        }
    }
}
/// Result of a local attempt to salvage tool-call arguments
/// (see `salvage_tool_arguments`).
#[derive(Debug, Clone, PartialEq)]
pub enum SalvageResult {
/// The input was already a JSON object that needed no changes.
AlreadyValid,
/// A usable JSON object was recovered; the payload is the repaired value.
Repaired(Value),
/// The input could not be turned into an acceptable object.
Unsalvageable,
}
/// Tries to recover usable tool-call arguments from `raw` without re-prompting.
///
/// * A JSON object has its schema-declared keys coerced and is then checked
///   against `schema`. It yields `AlreadyValid` when coercion changed nothing,
///   `Repaired` when it did, and `Unsalvageable` when the check fails.
/// * A JSON string is parsed leniently (code fences, surrounding prose, single
///   quotes and trailing commas are tolerated). A blank string stands for an
///   empty object. Strings longer than [`MAX_SALVAGE_INPUT_BYTES`] are rejected
///   before any parsing.
/// * Every other JSON value is `Unsalvageable`.
///
/// With `schema == None` no coercion or validation takes place.
pub fn salvage_tool_arguments(raw: &Value, schema: Option<&Value>) -> SalvageResult {
    // Shared tail of both paths: coerce the declared keys, then drop the
    // candidate if it still violates the schema.
    let coerce_checked = |candidate: Value| -> Option<Value> {
        let coerced = coerce_known_keys(candidate, schema);
        let rejected = coerced
            .as_object()
            .is_some_and(|fields| violates_schema(fields, schema));
        (!rejected).then_some(coerced)
    };

    let text = match raw {
        Value::Object(_) => {
            return match coerce_checked(raw.clone()) {
                None => SalvageResult::Unsalvageable,
                Some(coerced) if coerced == *raw => SalvageResult::AlreadyValid,
                Some(coerced) => SalvageResult::Repaired(coerced),
            };
        }
        Value::String(text) => text.as_str(),
        _ => return SalvageResult::Unsalvageable,
    };

    // The size guard applies to the untrimmed input so oversized payloads are
    // never scanned at all.
    if text.len() > MAX_SALVAGE_INPUT_BYTES {
        return SalvageResult::Unsalvageable;
    }

    let text = text.trim();
    if text.is_empty() {
        // A blank argument string is read as "no arguments".
        let no_fields = serde_json::Map::new();
        return if violates_schema(&no_fields, schema) {
            SalvageResult::Unsalvageable
        } else {
            SalvageResult::Repaired(Value::Object(no_fields))
        };
    }

    extract_json_object(text)
        .and_then(coerce_checked)
        .map_or(SalvageResult::Unsalvageable, SalvageResult::Repaired)
}
/// Reports whether `obj` breaks the subset of JSON Schema this module checks.
///
/// Two rules are enforced: every string entry of `required` must be present as
/// a key, and every present, non-null property whose schema declares a string
/// `type` must hold a value of that type. Unknown type names and properties
/// without a string `type` are accepted. Without a schema nothing is a
/// violation.
fn violates_schema(obj: &serde_json::Map<String, Value>, schema: Option<&Value>) -> bool {
    let Some(schema) = schema else {
        return false;
    };

    let missing_required = schema
        .get("required")
        .and_then(Value::as_array)
        .is_some_and(|names| {
            names
                .iter()
                .filter_map(Value::as_str)
                .any(|name| !obj.contains_key(name))
        });
    if missing_required {
        return true;
    }

    let Some(properties) = schema.get("properties").and_then(Value::as_object) else {
        return false;
    };

    properties.iter().any(|(name, prop_schema)| {
        let declared = prop_schema.get("type").and_then(Value::as_str);
        match (declared, obj.get(name)) {
            // Explicit nulls are tolerated regardless of the declared type.
            (Some(declared), Some(actual)) if !actual.is_null() => {
                let type_ok = match declared {
                    "integer" => actual.is_i64() || actual.is_u64(),
                    "number" => actual.is_number(),
                    "boolean" => actual.is_boolean(),
                    "string" => actual.is_string(),
                    "array" => actual.is_array(),
                    "object" => actual.is_object(),
                    _ => true,
                };
                !type_ok
            }
            _ => false,
        }
    })
}
/// Pulls a JSON object out of model-produced text.
///
/// Attempts, in order: parse the de-fenced text as-is; parse the first
/// brace-balanced span inside it; parse that span again after relaxing single
/// quotes and trailing commas. Returns `None` when no attempt yields a JSON
/// object.
fn extract_json_object(input: &str) -> Option<Value> {
    // Strict parse that only accepts a top-level object.
    let parse_object = |text: &str| -> Option<Value> {
        serde_json::from_str::<Value>(text)
            .ok()
            .filter(Value::is_object)
    };

    let unfenced = strip_code_fences(input);
    if let Some(object) = parse_object(unfenced.trim()) {
        return Some(object);
    }

    let span = first_balanced_object_span(unfenced)?;
    let embedded = &unfenced[span];
    parse_object(embedded).or_else(|| parse_object(&relax_json(embedded)))
}
/// Removes a surrounding Markdown code fence from `input`, if there is one.
///
/// The input is trimmed first. When it starts with three backticks, the rest
/// of that first line (the language tag position) is discarded, a closing
/// three-backtick suffix is removed if present, and the remainder is trimmed.
/// Text that does not start with a fence is returned trimmed and otherwise
/// untouched.
fn strip_code_fences(input: &str) -> &str {
    let text = input.trim();
    match text.strip_prefix("```") {
        None => text,
        Some(rest) => {
            // Everything up to the first newline is treated as the info string.
            let body = rest.split_once('\n').map_or(rest, |(_info, body)| body);
            body.strip_suffix("```").unwrap_or(body).trim()
        }
    }
}
/// Locates the first brace-balanced `{ ... }` region in `input`.
///
/// Scanning begins at the first `{`. Braces inside string literals are
/// ignored; a literal may be delimited by `"` or `'` and may contain backslash
/// escapes. Returns the byte range of the region (both braces included), or
/// `None` when there is no `{` or the braces never balance.
fn first_balanced_object_span(input: &str) -> Option<std::ops::Range<usize>> {
    // `{` is ASCII, so the character offset found here is also a valid byte offset.
    let open = input.find('{')?;
    let mut nesting: u32 = 0;
    // `Some(q)` while inside a string literal opened by quote byte `q`.
    let mut open_quote: Option<u8> = None;
    let mut skip_next = false;

    for (offset, byte) in input.bytes().enumerate().skip(open) {
        if let Some(delimiter) = open_quote {
            if skip_next {
                skip_next = false;
            } else if byte == b'\\' {
                skip_next = true;
            } else if byte == delimiter {
                open_quote = None;
            }
            continue;
        }

        match byte {
            b'"' | b'\'' => open_quote = Some(byte),
            b'{' => nesting += 1,
            b'}' => {
                // The scan starts on `{`, so `nesting` is at least 1 here.
                nesting -= 1;
                if nesting == 0 {
                    return Some(open..offset + 1);
                }
            }
            _ => {}
        }
    }
    None
}
/// Rewrites "almost JSON" into text a strict JSON parser has a chance to accept.
///
/// String literals may be delimited by `"` or `'` (the same quoting model used
/// when locating the object span). Each literal is re-emitted with `"`
/// delimiters:
///
/// * a bare `"` inside a single-quoted literal is escaped as `\"`, so it no
///   longer terminates the converted string;
/// * `\'` loses its backslash, because `\'` is not a valid JSON escape and a
///   plain `'` needs no escaping once the literal is double-quoted;
/// * every other escape sequence is copied through unchanged.
///
/// Trailing commas before `}` or `]` are removed afterwards.
fn relax_json(slice: &str) -> String {
    let mut out = String::with_capacity(slice.len());
    // Delimiter of the string literal currently being copied, if any.
    let mut open_quote: Option<char> = None;
    // True when the previous character was an unconsumed backslash inside a literal.
    let mut escaped = false;

    for ch in slice.chars() {
        let Some(delimiter) = open_quote else {
            match ch {
                '"' | '\'' => {
                    open_quote = Some(ch);
                    out.push('"');
                }
                _ => out.push(ch),
            }
            continue;
        };

        if escaped {
            escaped = false;
            // The backslash was held back until now so `\'` can drop it.
            if ch != '\'' {
                out.push('\\');
            }
            out.push(ch);
        } else if ch == '\\' {
            escaped = true;
        } else if ch == delimiter {
            open_quote = None;
            out.push('"');
        } else if ch == '"' {
            // Only reachable inside a single-quoted literal.
            out.push_str("\\\"");
        } else {
            out.push(ch);
        }
    }

    // Keep a dangling backslash so malformed input is not silently altered.
    if escaped {
        out.push('\\');
    }

    strip_trailing_commas(&out)
}
/// Returns `input` with every comma dropped that is followed, after optional
/// whitespace, by `}` or `]`.
///
/// Only double-quoted string literals (with backslash escapes) are recognised;
/// commas inside them are always kept.
fn strip_trailing_commas(input: &str) -> String {
    let mut cleaned = String::with_capacity(input.len());
    let mut inside_string = false;
    let mut after_backslash = false;

    for (pos, ch) in input.char_indices() {
        if inside_string {
            cleaned.push(ch);
            match ch {
                _ if after_backslash => after_backslash = false,
                '\\' => after_backslash = true,
                '"' => inside_string = false,
                _ => {}
            }
            continue;
        }

        match ch {
            '"' => inside_string = true,
            ',' => {
                // `,` is one byte wide, so `pos + 1` is a character boundary.
                let next_token = input[pos + 1..].trim_start().chars().next();
                if matches!(next_token, Some('}' | ']')) {
                    continue;
                }
            }
            _ => {}
        }
        cleaned.push(ch);
    }
    cleaned
}
/// Converts string values to the primitive type their schema property declares.
///
/// For every entry of `schema.properties` with a string `type` of `integer`,
/// `number` or `boolean`, a string value under the same key in `value` is
/// replaced by the parsed primitive when the trimmed text parses cleanly.
/// Values that are not strings, fail to parse, or belong to other declared
/// types are left untouched. Non-object input and a missing schema or
/// `properties` map return `value` unchanged.
///
/// Numbers are only coerced when finite. Rust's `f64` parser accepts spellings
/// such as `"nan"` and `"inf"` (and overflows like `"1e999"` to infinity), and
/// `serde_json` converts non-finite floats to `null`, which would silently
/// discard the caller's argument.
fn coerce_known_keys(value: Value, schema: Option<&Value>) -> Value {
    let mut obj = match value {
        Value::Object(map) => map,
        other => return other,
    };
    let Some(props) = schema
        .and_then(|s| s.get("properties"))
        .and_then(Value::as_object)
    else {
        return Value::Object(obj);
    };

    for (key, prop_schema) in props {
        let Some(declared) = prop_schema.get("type").and_then(Value::as_str) else {
            continue;
        };
        let Some(text) = obj.get(key).and_then(Value::as_str).map(str::trim) else {
            continue;
        };

        let coerced = match declared {
            // Fall back to `u64` so values above `i64::MAX` are coerced too.
            "integer" => text
                .parse::<i64>()
                .map(Value::from)
                .or_else(|_| text.parse::<u64>().map(Value::from))
                .ok(),
            "number" => text
                .parse::<f64>()
                .ok()
                .filter(|number| number.is_finite())
                .map(Value::from),
            "boolean" => match text {
                "true" => Some(Value::Bool(true)),
                "false" => Some(Value::Bool(false)),
                _ => None,
            },
            _ => None,
        };

        if let Some(coerced) = coerced {
            obj.insert(key.clone(), coerced);
        }
    }
    Value::Object(obj)
}
/// Runtime settings for tool-call repair.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct ToolCallRepairConfig {
/// Number of corrective re-prompts allowed before `outcome_after_failed_salvage`
/// reports `RepairOutcome::GaveUp`.
pub max_reprompts: u32,
}
/// The default configuration allows `DEFAULT_MAX_REPROMPTS` corrective re-prompts.
impl Default for ToolCallRepairConfig {
fn default() -> Self {
Self {
max_reprompts: DEFAULT_MAX_REPROMPTS,
}
}
}
impl ToolCallRepairConfig {
    /// Builds a config from a JSON value.
    ///
    /// `max_reprompts` is read when present as a non-negative integer; a
    /// missing key or any other value falls back to [`DEFAULT_MAX_REPROMPTS`].
    /// Values larger than `u32::MAX` saturate instead of wrapping, so an
    /// oversized setting can never turn into a small (or zero) budget.
    pub fn from_json(config: &Value) -> Self {
        let max_reprompts = config
            .get("max_reprompts")
            .and_then(Value::as_u64)
            .map_or(DEFAULT_MAX_REPROMPTS, |requested| {
                u32::try_from(requested).unwrap_or(u32::MAX)
            });
        Self { max_reprompts }
    }

    /// Decides what to do after local salvage has failed.
    ///
    /// Returns [`RepairOutcome::Reprompt`] while `prior_attempts` is below
    /// `max_reprompts`, and [`RepairOutcome::GaveUp`] otherwise.
    pub fn outcome_after_failed_salvage(&self, prior_attempts: u32) -> RepairOutcome {
        if prior_attempts < self.max_reprompts {
            RepairOutcome::Reprompt
        } else {
            RepairOutcome::GaveUp
        }
    }
}
/// Unit type that implements `Capability` for tool-call repair.
pub struct ToolCallRepairCapability;
/// Capability metadata, config schema and config validation for tool-call repair.
impl Capability for ToolCallRepairCapability {
    /// Stable identifier of this capability.
    fn id(&self) -> &str {
        TOOL_CALL_REPAIR_CAPABILITY_ID
    }

    /// Human-readable capability name.
    fn name(&self) -> &str {
        "Tool Call Repair"
    }

    /// Short description of what the capability does.
    fn description(&self) -> &str {
        "Detects and repairs malformed tool-call arguments from the model, \
         recovering the turn instead of surfacing a raw parse error."
    }

    /// This capability reports itself as a guardrail.
    fn is_guardrail(&self) -> bool {
        true
    }

    /// JSON Schema for the config: one optional integer `max_reprompts` in `0..=5`.
    fn config_schema(&self) -> Option<Value> {
        let schema = serde_json::json!({
            "type": "object",
            "properties": {
                "max_reprompts": {
                    "type": "integer",
                    "title": "Max corrective re-prompts",
                    "description": "How many corrective re-prompt attempts are allowed per malformed tool call before falling through to the normal error path.",
                    "minimum": 0,
                    "maximum": 5,
                    "default": DEFAULT_MAX_REPROMPTS
                }
            }
        });
        Some(schema)
    }

    /// Accepts `null`, or an object whose `max_reprompts` (if present) is an
    /// integer between 0 and 5; everything else is rejected with a message.
    fn validate_config(&self, config: &Value) -> Result<(), String> {
        if config.is_null() {
            return Ok(());
        }
        if !config.is_object() {
            return Err("tool_call_repair config must be an object".to_string());
        }
        let Some(value) = config.get("max_reprompts") else {
            return Ok(());
        };
        if value.as_u64().is_some_and(|n| n <= 5) {
            Ok(())
        } else {
            Err(format!(
                "max_reprompts must be an integer between 0 and 5, got {value}"
            ))
        }
    }

    /// Localized text; only an English config description is supplied.
    fn localizations(&self) -> Vec<CapabilityLocalization> {
        let english = CapabilityLocalization {
            locale: "en",
            name: None,
            description: None,
            config_description: Some(
                "Controls how many corrective re-prompts are attempted before a malformed tool call falls through to the normal error path.",
            ),
            config_overlay: None,
        };
        vec![english]
    }
}
/// Creates the tool-call repair capability as a shared `Capability` trait object.
pub fn tool_call_repair_capability() -> Arc<dyn Capability> {
Arc::new(ToolCallRepairCapability)
}
// Unit tests for argument salvage, the re-prompt policy and capability metadata.
#[cfg(test)]
mod tests {
use super::*;
use serde_json::json;
// Schema shared by the salvage tests: a required string `path` plus optional
// integer `limit`, number `ratio` and boolean `recursive`.
fn schema() -> Value {
json!({
"type": "object",
"properties": {
"path": { "type": "string" },
"limit": { "type": "integer" },
"ratio": { "type": "number" },
"recursive": { "type": "boolean" }
},
"required": ["path"]
})
}
// --- Object input -------------------------------------------------------
#[test]
fn already_valid_object_is_noop() {
let raw = json!({ "path": "/foo", "limit": 10 });
assert_eq!(
salvage_tool_arguments(&raw, Some(&schema())),
SalvageResult::AlreadyValid
);
}
#[test]
fn already_valid_object_without_schema_is_noop() {
let raw = json!({ "anything": 1 });
assert_eq!(
salvage_tool_arguments(&raw, None),
SalvageResult::AlreadyValid
);
}
// --- Blank string input ---------------------------------------------------
#[test]
fn empty_string_becomes_empty_object() {
let raw = json!(" ");
assert_eq!(
salvage_tool_arguments(&raw, None),
SalvageResult::Repaired(json!({}))
);
}
#[test]
fn empty_string_with_required_schema_is_unsalvageable() {
let raw = json!("");
assert_eq!(
salvage_tool_arguments(&raw, Some(&schema())),
SalvageResult::Unsalvageable
);
}
// --- Schema violations ----------------------------------------------------
#[test]
fn object_missing_required_key_is_unsalvageable() {
let raw = json!({ "limit": 3 });
assert_eq!(
salvage_tool_arguments(&raw, Some(&schema())),
SalvageResult::Unsalvageable
);
}
#[test]
fn object_with_uncoercible_type_is_unsalvageable() {
let raw = json!({ "path": "/foo", "limit": "abc" });
assert_eq!(
salvage_tool_arguments(&raw, Some(&schema())),
SalvageResult::Unsalvageable
);
}
#[test]
fn extracted_object_missing_required_is_unsalvageable() {
let raw = json!("here you go: {\"limit\": 3}");
assert_eq!(
salvage_tool_arguments(&raw, Some(&schema())),
SalvageResult::Unsalvageable
);
}
#[test]
fn object_missing_required_key_without_schema_is_noop() {
let raw = json!({ "limit": 3 });
assert_eq!(
salvage_tool_arguments(&raw, None),
SalvageResult::AlreadyValid
);
}
// --- String input: extraction and relaxation ------------------------------
#[test]
fn raw_string_object_is_parsed() {
let raw = json!("{\"path\": \"/foo\"}");
assert_eq!(
salvage_tool_arguments(&raw, Some(&schema())),
SalvageResult::Repaired(json!({ "path": "/foo" }))
);
}
#[test]
fn fenced_json_block_is_unwrapped() {
let raw = json!("```json\n{\"path\": \"/foo\"}\n```");
assert_eq!(
salvage_tool_arguments(&raw, Some(&schema())),
SalvageResult::Repaired(json!({ "path": "/foo" }))
);
}
#[test]
fn bare_fenced_block_is_unwrapped() {
let raw = json!("```\n{\"path\": \"/bar\"}\n```");
assert_eq!(
salvage_tool_arguments(&raw, Some(&schema())),
SalvageResult::Repaired(json!({ "path": "/bar" }))
);
}
#[test]
fn leading_and_trailing_prose_is_stripped() {
let raw = json!("Sure! Here are the args: {\"path\": \"/foo\"} hope that helps");
assert_eq!(
salvage_tool_arguments(&raw, Some(&schema())),
SalvageResult::Repaired(json!({ "path": "/foo" }))
);
}
#[test]
fn trailing_commas_are_removed() {
let raw = json!("{\"path\": \"/foo\", \"limit\": 3,}");
assert_eq!(
salvage_tool_arguments(&raw, Some(&schema())),
SalvageResult::Repaired(json!({ "path": "/foo", "limit": 3 }))
);
}
#[test]
fn single_quotes_are_normalized() {
let raw = json!("{'path': '/foo', 'limit': 5}");
assert_eq!(
salvage_tool_arguments(&raw, Some(&schema())),
SalvageResult::Repaired(json!({ "path": "/foo", "limit": 5 }))
);
}
#[test]
fn apostrophe_inside_double_quoted_value_survives() {
let raw = json!("{\"path\": \"it's here\"}");
assert_eq!(
salvage_tool_arguments(&raw, Some(&schema())),
SalvageResult::Repaired(json!({ "path": "it's here" }))
);
}
#[test]
fn brace_inside_string_does_not_break_span() {
let raw = json!("prose {\"path\": \"a}b\"} more");
assert_eq!(
salvage_tool_arguments(&raw, Some(&schema())),
SalvageResult::Repaired(json!({ "path": "a}b" }))
);
}
// --- Type coercion against the schema -------------------------------------
#[test]
fn known_keys_are_coerced_against_schema() {
let raw = json!(
"{\"path\": \"/foo\", \"limit\": \"42\", \"ratio\": \"1.5\", \"recursive\": \"true\"}"
);
assert_eq!(
salvage_tool_arguments(&raw, Some(&schema())),
SalvageResult::Repaired(
json!({ "path": "/foo", "limit": 42, "ratio": 1.5, "recursive": true })
)
);
}
#[test]
fn object_with_string_typed_known_key_is_coerced_in_place() {
let raw = json!({ "path": "/foo", "limit": "7" });
assert_eq!(
salvage_tool_arguments(&raw, Some(&schema())),
SalvageResult::Repaired(json!({ "path": "/foo", "limit": 7 }))
);
}
// --- Inputs that cannot be salvaged ---------------------------------------
#[test]
fn unparseable_garbage_is_unsalvageable() {
let raw = json!("path equals slash foo, no json here at all");
assert_eq!(
salvage_tool_arguments(&raw, Some(&schema())),
SalvageResult::Unsalvageable
);
}
#[test]
fn oversized_input_is_rejected_without_parsing() {
let big = format!("{{\"path\": \"{}\"}}", "a".repeat(MAX_SALVAGE_INPUT_BYTES));
let raw = json!(big);
assert_eq!(
salvage_tool_arguments(&raw, Some(&schema())),
SalvageResult::Unsalvageable
);
}
#[test]
fn non_object_non_string_is_unsalvageable() {
assert_eq!(
salvage_tool_arguments(&json!(42), None),
SalvageResult::Unsalvageable
);
assert_eq!(
salvage_tool_arguments(&json!([1, 2]), None),
SalvageResult::Unsalvageable
);
}
// --- Re-prompt policy and config parsing ----------------------------------
#[test]
fn outcome_reprompts_until_cap_then_gives_up() {
let cfg = ToolCallRepairConfig { max_reprompts: 2 };
assert_eq!(cfg.outcome_after_failed_salvage(0), RepairOutcome::Reprompt);
assert_eq!(cfg.outcome_after_failed_salvage(1), RepairOutcome::Reprompt);
assert_eq!(cfg.outcome_after_failed_salvage(2), RepairOutcome::GaveUp);
assert_eq!(cfg.outcome_after_failed_salvage(3), RepairOutcome::GaveUp);
}
#[test]
fn zero_reprompts_gives_up_immediately() {
let cfg = ToolCallRepairConfig { max_reprompts: 0 };
assert_eq!(cfg.outcome_after_failed_salvage(0), RepairOutcome::GaveUp);
}
#[test]
fn config_parses_from_json_with_defaults() {
assert_eq!(
ToolCallRepairConfig::from_json(&json!({})),
ToolCallRepairConfig::default()
);
assert_eq!(
ToolCallRepairConfig::from_json(&json!({ "max_reprompts": 3 })).max_reprompts,
3
);
}
// --- Labels and capability metadata ---------------------------------------
#[test]
fn outcome_labels_are_stable() {
assert_eq!(RepairOutcome::LocalSalvage.label(), "local-salvage");
assert_eq!(RepairOutcome::Reprompt.label(), "re-prompt");
assert_eq!(RepairOutcome::GaveUp.label(), "gave-up");
}
#[test]
fn capability_id_and_validation() {
let cap = ToolCallRepairCapability;
assert_eq!(cap.id(), TOOL_CALL_REPAIR_CAPABILITY_ID);
assert!(cap.is_guardrail());
assert!(cap.config_schema().is_some());
assert!(cap.validate_config(&Value::Null).is_ok());
assert!(cap.validate_config(&json!({})).is_ok());
assert!(cap.validate_config(&json!({ "max_reprompts": 2 })).is_ok());
assert!(cap.validate_config(&json!({ "max_reprompts": 9 })).is_err());
assert!(
cap.validate_config(&json!({ "max_reprompts": "x" }))
.is_err()
);
}
}