use async_trait::async_trait;
use serde_json::Value;
use crate::middleware::{AgentState, Middleware, Result};
/// Middleware that patches tool calls emitted by the model: tool names that
/// are not in the known list are replaced by the closest known name, and
/// string-encoded `args` payloads can be repaired into parsed JSON.
#[derive(Debug, Clone)]
pub struct PatchToolCallsMiddleware {
    /// Tool names considered valid; misspelled names are matched against these.
    known_tools: Vec<String>,
    /// Whether string-valued `args` should be run through JSON repair.
    fix_json: bool,
}

impl PatchToolCallsMiddleware {
    /// Creates a middleware that recognises `known_tools`.
    ///
    /// JSON repair is enabled by default; use [`Self::set_fix_json`] to turn
    /// it off.
    pub fn new(known_tools: Vec<String>) -> Self {
        Self {
            known_tools,
            fix_json: true,
        }
    }

    /// Enables or disables JSON repair of string-valued `args`.
    ///
    /// Returns `&mut Self` so further setters can be chained.
    pub fn set_fix_json(&mut self, fix: bool) -> &mut Self {
        self.fix_json = fix;
        self
    }
}
/// Computes the Levenshtein edit distance between `a` and `b`.
///
/// The distance is measured in Unicode scalar values (`char`s), not bytes, so
/// replacing `e` with `é` costs exactly one edit.
///
/// Uses the two-row dynamic-programming formulation: `prev` holds the
/// distances for the previous character of `a`, `curr` the row being filled.
fn levenshtein(a: &str, b: &str) -> usize {
    // Collect `b` once so the row width and the final index are both expressed
    // in characters; sizing by `str::len` (bytes) would read a stale cell for
    // any non-ASCII input.
    let b_chars: Vec<char> = b.chars().collect();
    let b_len = b_chars.len();
    if b_len == 0 {
        return a.chars().count();
    }

    let mut prev: Vec<usize> = (0..=b_len).collect();
    let mut curr = vec![0; b_len + 1];
    for (i, ca) in a.chars().enumerate() {
        curr[0] = i + 1;
        for (j, &cb) in b_chars.iter().enumerate() {
            let cost = usize::from(ca != cb);
            // substitution, deletion, insertion
            curr[j + 1] = (prev[j] + cost).min(prev[j + 1] + 1).min(curr[j] + 1);
        }
        std::mem::swap(&mut prev, &mut curr);
    }
    // After the final swap `prev` is the last completed row; for an empty `a`
    // it is still the initial row `0..=b_len`.
    prev[b_len]
}
/// Returns the entry of `known` with the smallest edit distance to `name`,
/// provided that distance is within the acceptance threshold.
///
/// The threshold is half the length of `name` in characters, but never less
/// than 2. When several tools are equally close, the one listed first in
/// `known` wins. Returns `None` when `known` is empty or the best candidate is
/// too far away.
fn find_closest_tool(name: &str, known: &[String]) -> Option<String> {
    // `min_by_key` yields the first of several equal minima, and `None` for an
    // empty list, so no separate emptiness check is needed.
    let (best_tool, best_dist) = known
        .iter()
        .map(|tool| (tool, levenshtein(name, tool)))
        .min_by_key(|&(_, dist)| dist)?;

    // Count characters rather than bytes: a byte-based threshold is inflated
    // for multi-byte names and would accept completely unrelated tools.
    let threshold = (name.chars().count() / 2).max(2);
    if best_dist <= threshold {
        Some(best_tool.clone())
    } else {
        None
    }
}
/// Attempts to turn almost-JSON text into valid JSON.
///
/// Two repairs are applied:
/// 1. single-quoted strings are rewritten as double-quoted strings, and
/// 2. trailing commas before `}` or `]` are removed, repeatedly, until the
///    text stops changing (so `[1,,]` is fully cleaned).
///
/// The result is not guaranteed to be valid JSON; callers should still parse it.
fn repair_json(input: &str) -> String {
    let mut result = normalize_single_quotes(input);
    loop {
        let stripped = remove_trailing_commas(&result);
        if stripped == result {
            return result;
        }
        result = stripped;
    }
}

/// Rewrites single-quoted string literals as double-quoted ones.
///
/// Apostrophes inside double-quoted strings are left alone, and a `"` inside
/// a single-quoted string is escaped so the converted literal stays
/// well-formed. The escape `\'`, which JSON does not define, is reduced to a
/// plain apostrophe inside either kind of string.
fn normalize_single_quotes(input: &str) -> String {
    #[derive(Clone, Copy)]
    enum Quote {
        Outside,
        Double,
        Single,
    }

    let mut out = String::with_capacity(input.len());
    let mut ctx = Quote::Outside;
    let mut chars = input.chars();
    // `while let` rather than `for`: escape handling consumes the next char.
    while let Some(ch) = chars.next() {
        match (ctx, ch) {
            (Quote::Outside, '"') => {
                ctx = Quote::Double;
                out.push('"');
            }
            (Quote::Outside, '\'') => {
                ctx = Quote::Single;
                out.push('"');
            }
            (Quote::Double, '"') | (Quote::Single, '\'') => {
                ctx = Quote::Outside;
                out.push('"');
            }
            // A bare double quote inside '...' must be escaped once the
            // delimiters become double quotes.
            (Quote::Single, '"') => out.push_str("\\\""),
            (Quote::Double, '\\') | (Quote::Single, '\\') => match chars.next() {
                Some('\'') => out.push('\''),
                Some(next) => {
                    out.push('\\');
                    out.push(next);
                }
                None => out.push('\\'),
            },
            (_, other) => out.push(other),
        }
    }
    out
}
/// Removes, in a single pass, every comma that is followed (after optional
/// whitespace) by a closing `}` or `]`.
///
/// Commas inside double-quoted string literals are data, not separators, and
/// are never removed; backslash escapes are honoured when looking for the end
/// of a string. Only the comma itself is dropped; surrounding whitespace is
/// kept. Consecutive commas such as `,,]` need more than one pass.
fn remove_trailing_commas(s: &str) -> String {
    let mut result = String::with_capacity(s.len());
    let mut in_string = false;
    let mut escaped = false;

    for (idx, ch) in s.char_indices() {
        if in_string {
            if escaped {
                // The character after a backslash never terminates the string.
                escaped = false;
            } else if ch == '\\' {
                escaped = true;
            } else if ch == '"' {
                in_string = false;
            }
            result.push(ch);
            continue;
        }

        if ch == '"' {
            in_string = true;
        } else if ch == ',' {
            // Look past whitespace for the next significant character.
            let next = s[idx + ch.len_utf8()..]
                .chars()
                .find(|c| !c.is_whitespace());
            if next == Some('}') || next == Some(']') {
                continue;
            }
        }
        result.push(ch);
    }
    result
}
#[async_trait]
impl Middleware for PatchToolCallsMiddleware {
    /// Name reported for this middleware.
    fn name(&self) -> &str {
        "patch_tool_calls"
    }

    /// Patches the tool calls of the most recent message whose `type` is
    /// `"ai"` in the `messages` array of `state`.
    ///
    /// For each entry of that message's `tool_calls` array:
    /// - a string `name` that is not a known tool is replaced by the closest
    ///   known tool, if one is close enough;
    /// - when JSON repair is enabled, a string `args` is replaced by its
    ///   parsed JSON value. The raw string is parsed first and only repaired
    ///   if that fails, so already-valid payloads are never altered. If the
    ///   repaired text still does not parse, it is stored as a string.
    ///
    /// Nothing is changed, and `Ok(())` is returned, when no tools are known
    /// (this also skips JSON repair) or when the expected `messages` /
    /// `tool_calls` structure is absent.
    async fn after_model(&self, state: &mut AgentState) -> Result<()> {
        if self.known_tools.is_empty() {
            return Ok(());
        }
        let messages = match state.get_mut("messages").and_then(|v| v.as_array_mut()) {
            Some(m) => m,
            None => return Ok(()),
        };
        // Search from the end: only the latest AI message is patched.
        let last_ai = messages
            .iter_mut()
            .rev()
            .find(|m| m.get("type").and_then(|t| t.as_str()) == Some("ai"));
        let ai_msg = match last_ai {
            Some(m) => m,
            None => return Ok(()),
        };
        let tool_calls = match ai_msg.get_mut("tool_calls").and_then(|v| v.as_array_mut()) {
            Some(tc) => tc,
            None => return Ok(()),
        };

        for tool_call in tool_calls.iter_mut() {
            // Work out the replacement while only borrowing the name, then
            // write it back; this avoids cloning every tool name.
            let replacement = tool_call
                .get("name")
                .and_then(|v| v.as_str())
                .filter(|name| !self.known_tools.iter().any(|tool| tool.as_str() == *name))
                .and_then(|name| find_closest_tool(name, &self.known_tools));
            if let Some(closest) = replacement {
                if let Some(name_val) = tool_call.get_mut("name") {
                    *name_val = Value::String(closest);
                }
            }

            if !self.fix_json {
                continue;
            }
            if let Some(args_val) = tool_call.get_mut("args") {
                // Non-string args (e.g. an already-parsed object) are left as is.
                let patched = args_val.as_str().map(|raw| {
                    serde_json::from_str::<Value>(raw).unwrap_or_else(|_| {
                        let repaired = repair_json(raw);
                        match serde_json::from_str::<Value>(&repaired) {
                            Ok(parsed) => parsed,
                            Err(_) => Value::String(repaired),
                        }
                    })
                });
                if let Some(new_args) = patched {
                    *args_val = new_args;
                }
            }
        }
        Ok(())
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    /// Builds an owned list of tool names from string literals.
    fn tool_names(names: &[&str]) -> Vec<String> {
        names.iter().map(|n| n.to_string()).collect()
    }

    /// Builds a state whose only message is an AI message carrying `tool_call`.
    fn single_ai_state(tool_call: Value) -> Value {
        json!({
            "messages": [
                {
                    "type": "ai",
                    "content": "",
                    "tool_calls": [tool_call]
                }
            ]
        })
    }

    /// Edit distance on empty, equal, and classic textbook inputs.
    #[test]
    fn test_levenshtein() {
        let cases = [
            ("", "", 0),
            ("abc", "abc", 0),
            ("abc", "", 3),
            ("", "abc", 3),
            ("kitten", "sitting", 3),
            ("calculater", "calculator", 1),
        ];
        for (left, right, expected) in cases.iter() {
            assert_eq!(levenshtein(left, right), *expected);
        }
    }

    /// A misspelled tool name is replaced by the closest known tool.
    #[tokio::test]
    async fn test_fix_misspelled_tool_name() {
        let middleware =
            PatchToolCallsMiddleware::new(tool_names(&["calculator", "search", "read_file"]));
        let mut state = json!({
            "messages": [
                { "type": "human", "content": "help" },
                {
                    "type": "ai",
                    "content": "",
                    "tool_calls": [
                        { "name": "calculater", "args": {"expr": "2+2"} }
                    ]
                }
            ]
        });

        middleware.after_model(&mut state).await.unwrap();

        let patched_calls = state["messages"][1]["tool_calls"].as_array().unwrap();
        assert_eq!(patched_calls[0]["name"], "calculator");
    }

    /// String args with a trailing comma are repaired and parsed.
    #[tokio::test]
    async fn test_fix_json_trailing_comma() {
        let middleware = PatchToolCallsMiddleware::new(tool_names(&["calculator"]));
        let mut state = single_ai_state(json!({
            "name": "calculator",
            "args": "{\"expr\": \"2+2\", }"
        }));

        middleware.after_model(&mut state).await.unwrap();

        let patched_args = &state["messages"][0]["tool_calls"][0]["args"];
        assert_eq!(patched_args["expr"], "2+2");
    }

    /// A well-formed call to a known tool passes through untouched.
    #[tokio::test]
    async fn test_valid_tool_calls_not_modified() {
        let middleware = PatchToolCallsMiddleware::new(tool_names(&["calculator", "search"]));
        let mut state = single_ai_state(json!({
            "name": "calculator",
            "args": {"expr": "2+2"}
        }));
        let snapshot = state.clone();

        middleware.after_model(&mut state).await.unwrap();

        assert_eq!(state, snapshot);
    }

    /// With no known tools the middleware leaves the state alone entirely.
    #[tokio::test]
    async fn test_no_known_tools_no_patching() {
        let middleware = PatchToolCallsMiddleware::new(Vec::new());
        let mut state = single_ai_state(json!({
            "name": "nonexistent",
            "args": "{bad json,}"
        }));
        let snapshot = state.clone();

        middleware.after_model(&mut state).await.unwrap();

        assert_eq!(state, snapshot);
    }

    /// Single-quoted pseudo-JSON becomes parseable JSON.
    #[test]
    fn test_repair_json_single_quotes() {
        let fixed = repair_json("{'key': 'value'}");
        let value: Value = serde_json::from_str(&fixed).unwrap();
        assert_eq!(value["key"], "value");
    }

    /// No candidates means no suggestion.
    #[test]
    fn test_find_closest_tool_empty() {
        let suggestion = find_closest_tool("anything", &[]);
        assert_eq!(suggestion, None);
    }

    /// An exact name matches itself.
    #[test]
    fn test_find_closest_tool_exact_match() {
        let known = tool_names(&["calculator"]);
        let suggestion = find_closest_tool("calculator", &known);
        assert_eq!(suggestion, Some("calculator".to_string()));
    }
}