use std::collections::HashMap;
use fuzzy_parser::distance::{find_closest, Algorithm};
use fuzzy_parser::{repair_object_fields, sanitize_json, FuzzyOptions, ObjectSchema};
use crate::decider::{ActionCandidate, DecisionResponse, LlmError};
/// Expected top-level field names of an action object.
///
/// `parse_json` passes this schema to `fuzzy_parser::repair_object_fields` so
/// that misspelled keys emitted by the model (e.g. `taget`) are corrected to
/// one of these names (e.g. `target`).
pub const ACTION_FIELDS: ObjectSchema =
ObjectSchema::new(&["tool", "target", "args", "confidence"]);
/// Returns the `name` of every candidate as an owned `String`, preserving the
/// input order. An empty slice yields an empty vector.
pub fn candidate_names(candidates: &[ActionCandidate]) -> Vec<String> {
    let mut names = Vec::with_capacity(candidates.len());
    for candidate in candidates {
        names.push(candidate.name.clone());
    }
    names
}
/// Pulls the JSON payload out of a free-form LLM response.
///
/// Two strategies are tried in order:
/// 1. The body of the first markdown code fence tagged `json`. It is used only
///    when a closing fence follows and the trimmed body is non-empty.
/// 2. The first brace-balanced `{ … }` object found by `extract_balanced_json`.
///
/// # Errors
/// Returns a permanent [`LlmError`] quoting the whole response when neither
/// strategy yields anything.
pub fn extract_json(text: &str) -> Result<String, LlmError> {
    let fenced_body = text
        .split_once("```json")
        .and_then(|(_, after_open)| after_open.split_once("```"))
        .map(|(body, _)| body.trim())
        .filter(|body| !body.is_empty());

    if let Some(body) = fenced_body {
        return Ok(body.to_string());
    }

    extract_balanced_json(text)
        .ok_or_else(|| LlmError::permanent(format!("No JSON found in response: {}", text)))
}
/// Returns the brace-balanced `{ … }` span that starts at the first `{` in `text`.
///
/// Braces inside double-quoted strings are ignored, and a backslash inside a
/// string escapes the next character, so `"\"}"` does not close the object.
/// Returns `None` when `text` has no `{` or when that first `{` is never closed.
fn extract_balanced_json(text: &str) -> Option<String> {
    let open = text.find('{')?;
    let tail = &text[open..];

    let mut nesting = 0;
    let mut inside_quotes = false;
    let mut skip_one = false;

    for (offset, c) in tail.char_indices() {
        if skip_one {
            skip_one = false;
            continue;
        }
        if inside_quotes {
            match c {
                '\\' => skip_one = true,
                '"' => inside_quotes = false,
                _ => {}
            }
            continue;
        }
        match c {
            '"' => inside_quotes = true,
            '{' => nesting += 1,
            '}' => {
                nesting -= 1;
                if nesting == 0 {
                    // `}` is one byte, so `..=offset` ends on a char boundary.
                    return Some(tail[..=offset].to_string());
                }
            }
            _ => {}
        }
    }
    None
}
/// Wraps bare identifier keys in double quotes, e.g. `{tool: "Read"}` becomes
/// `{"tool": "Read"}`.
///
/// An identifier counts as a key only when all of the following hold:
/// - it appears right after `{` or `,`, optionally preceded by whitespace;
/// - it starts with an alphabetic character or `_` and continues with
///   alphanumerics or `_`;
/// - it is followed, after optional whitespace, by `:`.
///
/// Whitespace between such a key and its colon is dropped. Identifiers that are
/// not keys (e.g. `true` in an array) are copied through unchanged, together
/// with any whitespace after them.
///
/// Text inside double-quoted strings, honoring backslash escapes, is copied
/// verbatim. This keeps string values such as `"a, b: c"` or `"x, then y"`
/// intact.
fn fix_unquoted_keys(json: &str) -> String {
    let chars: Vec<char> = json.chars().collect();
    let len = chars.len();
    let mut result = String::with_capacity(json.len() + 16);
    let mut i = 0;
    let mut in_string = false;
    let mut escape_next = false;

    while i < len {
        let ch = chars[i];

        // Inside a string literal: copy verbatim, only watching for its end.
        if in_string {
            result.push(ch);
            i += 1;
            if escape_next {
                escape_next = false;
            } else if ch == '\\' {
                escape_next = true;
            } else if ch == '"' {
                in_string = false;
            }
            continue;
        }

        match ch {
            '"' => {
                in_string = true;
                result.push(ch);
                i += 1;
            }
            '{' | ',' => {
                result.push(ch);
                i += 1;
                while i < len && chars[i].is_whitespace() {
                    result.push(chars[i]);
                    i += 1;
                }
                if i < len && (chars[i].is_alphabetic() || chars[i] == '_') {
                    let key_start = i;
                    while i < len && (chars[i].is_alphanumeric() || chars[i] == '_') {
                        i += 1;
                    }
                    let key_end = i;
                    while i < len && chars[i].is_whitespace() {
                        i += 1;
                    }
                    if i < len && chars[i] == ':' {
                        result.push('"');
                        result.extend(&chars[key_start..key_end]);
                        result.push_str("\":");
                        i += 1;
                    } else {
                        // Not a key: emit the identifier *and* the whitespace we
                        // skipped after it, so nothing is lost.
                        result.extend(&chars[key_start..i]);
                    }
                }
            }
            _ => {
                result.push(ch);
                i += 1;
            }
        }
    }
    result
}
/// Parses a JSON action object into a [`DecisionResponse`], first repairing
/// common LLM formatting mistakes.
///
/// Repairs are applied in this order:
/// 1. Bare identifier keys are quoted (`{tool: "Read"}` becomes
///    `{"tool": "Read"}`).
/// 2. Misspelled top-level keys are corrected against [`ACTION_FIELDS`]
///    (e.g. `taget` becomes `target`) via `fuzzy_parser::repair_object_fields`.
/// 3. A `tool` value that is not exactly one of `candidates` is replaced by the
///    closest candidate (Jaro-Winkler, threshold `0.6`) when one exists.
///
/// Fields are handled as follows:
/// - `tool` and `target` are required strings.
/// - `confidence` defaults to `0.5` when absent or not a number.
/// - `args` entries whose value is a string, number, or boolean are kept.
///   Numbers and booleans are stored as their JSON text (`10`, `true`).
///   `null`, array, and object values are skipped.
/// - A missing or non-object `args` yields an empty map.
///
/// `reasoning`, `prompt`, and `raw_response` are always `None`.
///
/// # Errors
/// Returns a permanent [`LlmError`] if the key-fixed text is not valid JSON, or
/// if `tool` or `target` is missing or not a string.
pub fn parse_json(json: &str, candidates: &[String]) -> Result<DecisionResponse, LlmError> {
    let fixed_json = fix_unquoted_keys(json);
    // The error quotes the caller's original text rather than the key-fixed one.
    let mut parsed: serde_json::Value = serde_json::from_str(&fixed_json)
        .map_err(|e| LlmError::permanent(format!("JSON parse error: {} (json: {})", e, json)))?;

    let options = FuzzyOptions::default();
    // Field-name repair only runs when the top-level value is an object.
    if let Some(obj) = parsed.as_object_mut() {
        let corrections = repair_object_fields(obj, &ACTION_FIELDS, "$", &options);
        if !corrections.is_empty() {
            tracing::debug!(
                corrections = ?corrections
                    .iter()
                    .map(|c| format!("{} → {}", c.original, c.corrected))
                    .collect::<Vec<_>>(),
                "Fuzzy repaired field names"
            );
        }
    }

    let tool_val = parsed["tool"]
        .as_str()
        .ok_or_else(|| LlmError::permanent("Missing 'tool' field"))?;
    let tool = resolve_tool_name(tool_val, candidates);

    let target = parsed["target"]
        .as_str()
        .ok_or_else(|| LlmError::permanent("Missing 'target' field"))?
        .to_string();

    let confidence = parsed["confidence"].as_f64().unwrap_or(0.5);

    let args = parsed["args"]
        .as_object()
        .map(collect_args)
        .unwrap_or_default();

    Ok(DecisionResponse {
        tool,
        target,
        args,
        reasoning: None,
        confidence,
        prompt: None,
        raw_response: None,
    })
}

/// Maps a raw `tool` value onto `candidates`.
///
/// - An exact match is returned unchanged.
/// - Otherwise the closest candidate by Jaro-Winkler similarity is used.
/// - With no candidates, or no candidate above the threshold, the raw value is
///   kept.
fn resolve_tool_name(tool_val: &str, candidates: &[String]) -> String {
    if candidates.is_empty() || candidates.iter().any(|c| c == tool_val) {
        return tool_val.to_string();
    }
    let candidate_strs: Vec<&str> = candidates.iter().map(|s| s.as_str()).collect();
    // 0.6 is the minimum similarity accepted as a fuzzy tool-name match.
    match find_closest(tool_val, candidate_strs, 0.6, Algorithm::JaroWinkler) {
        Some(m) => {
            tracing::debug!(
                original = tool_val,
                corrected = %m.candidate,
                similarity = m.similarity,
                "Fuzzy repaired tool name"
            );
            m.candidate.to_string()
        }
        None => tool_val.to_string(),
    }
}

/// Flattens an `args` JSON object into string pairs.
///
/// String values are copied as-is. Numbers and booleans become their JSON
/// text, so `10` becomes `"10"` and `true` becomes `"true"`; models often emit
/// these unquoted, and dropping them would silently lose arguments. `null`,
/// array, and object values are skipped.
fn collect_args(args_obj: &serde_json::Map<String, serde_json::Value>) -> HashMap<String, String> {
    args_obj
        .iter()
        .filter_map(|(key, value)| {
            let text = match value {
                serde_json::Value::String(s) => s.clone(),
                serde_json::Value::Number(n) => n.to_string(),
                serde_json::Value::Bool(b) => b.to_string(),
                _ => return None,
            };
            Some((key.clone(), text))
        })
        .collect()
}
/// Parses a raw LLM completion into a [`DecisionResponse`].
///
/// The pipeline has three steps:
/// 1. Locate the JSON payload with [`extract_json`].
/// 2. Clean it with `fuzzy_parser::sanitize_json`.
/// 3. Hand it to [`parse_json`] for key, field, and tool-name repair.
///
/// The raw completion and the sanitized JSON are both logged at `debug` level.
///
/// # Errors
/// Propagates any [`LlmError`] returned by [`extract_json`] or [`parse_json`].
pub fn parse_response(text: &str, candidates: &[String]) -> Result<DecisionResponse, LlmError> {
    tracing::debug!(raw_output = %text, "LLM raw response");

    let payload = extract_json(text)?;
    let cleaned = sanitize_json(&payload);
    tracing::debug!(sanitized = %cleaned, "Sanitized JSON");

    parse_json(&cleaned, candidates)
}
// Unit tests for candidate-name collection, JSON extraction, field/tool-name
// parsing and repair, and the end-to-end `parse_response` pipeline.
#[cfg(test)]
mod tests {
use super::*;
// --- candidate_names ---
#[test]
fn test_candidate_names() {
let candidates = vec![
ActionCandidate {
name: "Read".to_string(),
description: "Read a file".to_string(),
params: vec![],
example: None,
},
ActionCandidate {
name: "Write".to_string(),
description: "Write a file".to_string(),
params: vec![],
example: None,
},
];
let names = candidate_names(&candidates);
assert_eq!(names, vec!["Read".to_string(), "Write".to_string()]);
}
#[test]
fn test_candidate_names_empty() {
let candidates: Vec<ActionCandidate> = vec![];
let names = candidate_names(&candidates);
assert!(names.is_empty());
}
// --- extract_json: bare objects, surrounding prose, fenced blocks ---
#[test]
fn test_extract_json_direct() {
let text = r#"{"tool": "Read", "target": "src/main.rs", "confidence": 0.8}"#;
let extracted = extract_json(text).unwrap();
assert_eq!(extracted, text);
}
#[test]
fn test_extract_json_with_prefix() {
let text =
r#"Here is the action: {"tool": "Read", "target": "file.rs", "confidence": 0.8}"#;
let extracted = extract_json(text).unwrap();
assert!(extracted.starts_with('{'));
assert!(extracted.ends_with('}'));
assert!(extracted.contains("Read"));
}
#[test]
fn test_extract_json_with_suffix() {
let text = r#"{"tool": "Grep", "target": "pattern", "confidence": 0.9} That's the action."#;
let extracted = extract_json(text).unwrap();
assert!(extracted.contains("Grep"));
}
#[test]
fn test_extract_json_markdown_block_with_newline() {
let text = "```json\n{\"tool\": \"Read\", \"target\": \"file.rs\"}\n```";
let extracted = extract_json(text).unwrap();
assert!(extracted.contains("\"tool\": \"Read\""));
}
// The opening fence is immediately followed by the object, with no newline.
#[test]
fn test_extract_json_markdown_block_without_newline() {
let text = "```json{\"tool\": \"Read\", \"target\": \"file.rs\"}```";
let extracted = extract_json(text).unwrap();
assert!(extracted.contains("\"tool\": \"Read\""));
}
#[test]
fn test_extract_json_no_json() {
let text = "This is just plain text without any JSON.";
let result = extract_json(text);
assert!(result.is_err());
assert!(result.unwrap_err().message().contains("No JSON found"));
}
// The inner `}` of `args` must not terminate the outer object early.
#[test]
fn test_extract_json_nested_braces() {
let text = r#"{"tool": "Read", "target": "src/main.rs", "args": {"encoding": "utf8"}, "confidence": 0.8}"#;
let extracted = extract_json(text).unwrap();
assert!(extracted.contains("args"));
assert!(extracted.contains("encoding"));
}
// --- parse_json: field extraction, defaults, and required fields ---
#[test]
fn test_parse_json_basic() {
let candidates = vec!["Read".to_string(), "Write".to_string(), "Grep".to_string()];
let json = r#"{"tool": "Read", "target": "src/main.rs", "confidence": 0.85}"#;
let response = parse_json(json, &candidates).unwrap();
assert_eq!(response.tool, "Read");
assert_eq!(response.target, "src/main.rs");
assert!((response.confidence - 0.85).abs() < 0.01);
}
#[test]
fn test_parse_json_with_args() {
let candidates = vec!["Grep".to_string()];
let json =
r#"{"tool": "Grep", "target": "fn main", "args": {"path": "src/"}, "confidence": 0.9}"#;
let response = parse_json(json, &candidates).unwrap();
assert_eq!(response.tool, "Grep");
assert_eq!(response.target, "fn main");
assert_eq!(response.args.get("path"), Some(&"src/".to_string()));
}
// A missing `confidence` falls back to 0.5.
#[test]
fn test_parse_json_default_confidence() {
let candidates = vec!["Read".to_string()];
let json = r#"{"tool": "Read", "target": "file.rs"}"#;
let response = parse_json(json, &candidates).unwrap();
assert!((response.confidence - 0.5).abs() < 0.01);
}
#[test]
fn test_parse_json_missing_tool() {
let candidates = vec!["Read".to_string()];
let json = r#"{"target": "file.rs", "confidence": 0.8}"#;
let result = parse_json(json, &candidates);
assert!(result.is_err());
assert!(result.unwrap_err().message().contains("tool"));
}
#[test]
fn test_parse_json_missing_target() {
let candidates = vec!["Read".to_string()];
let json = r#"{"tool": "Read", "confidence": 0.8}"#;
let result = parse_json(json, &candidates);
assert!(result.is_err());
assert!(result.unwrap_err().message().contains("target"));
}
// --- fuzzy repair of tool names and field names ---
#[test]
fn test_fuzzy_repair_tool_typo() {
let candidates = vec![
"Read".to_string(),
"Grep".to_string(),
"Dir".to_string(),
"Write".to_string(),
];
let json = r#"{"tool": "Raed", "target": "src/main.rs", "confidence": 0.8}"#;
let response = parse_json(json, &candidates).unwrap();
assert_eq!(response.tool, "Read");
let json = r#"{"tool": "Grpe", "target": "pattern", "confidence": 0.8}"#;
let response = parse_json(json, &candidates).unwrap();
assert_eq!(response.tool, "Grep");
}
#[test]
fn test_fuzzy_repair_field_typo() {
let candidates = vec!["Read".to_string()];
let json = r#"{"tool": "Read", "taget": "src/main.rs", "confidence": 0.8}"#;
let response = parse_json(json, &candidates).unwrap();
assert_eq!(response.target, "src/main.rs");
let json = r#"{"tool": "Read", "target": "file.rs", "confindence": 0.9}"#;
let response = parse_json(json, &candidates).unwrap();
assert!((response.confidence - 0.9).abs() < 0.01);
}
// Only the supplied candidates are matched against: "Writ" must resolve to
// "Wait" here, presumably because "Write" is not in the list.
#[test]
fn test_fuzzy_repair_scoped_to_candidates() {
let candidates = vec!["Wait".to_string(), "Start".to_string()];
let json = r#"{"tool": "Writ", "target": "task", "confidence": 0.8}"#;
let response = parse_json(json, &candidates).unwrap();
assert_eq!(response.tool, "Wait"); }
// With no candidates, the tool name is passed through unrepaired.
#[test]
fn test_fuzzy_repair_no_candidates() {
let empty: Vec<String> = vec![];
let json = r#"{"tool": "Raed", "target": "file.rs", "confidence": 0.8}"#;
let response = parse_json(json, &empty).unwrap();
assert_eq!(response.tool, "Raed"); }
// --- parse_response: extraction + sanitization + parsing end to end ---
#[test]
fn test_parse_response_with_prefix() {
let candidates = vec!["Read".to_string(), "Grep".to_string()];
let text =
r#"I'll read the file: {"tool": "Read", "target": "src/main.rs", "confidence": 0.85}"#;
let response = parse_response(text, &candidates).unwrap();
assert_eq!(response.tool, "Read");
}
// The trailing comma is expected to be removed by `sanitize_json`.
#[test]
fn test_parse_response_trailing_comma() {
let candidates = vec!["Read".to_string()];
let text = r#"{"tool": "Read", "target": "file.rs", "confidence": 0.8,}"#;
let response = parse_response(text, &candidates).unwrap();
assert_eq!(response.tool, "Read");
}
#[test]
fn test_parse_response_with_reasoning() {
let candidates = vec!["Grep".to_string()];
let text = r#"Based on the task, I should search for the pattern.
{"tool": "Grep", "target": "fn main", "confidence": 0.9}"#;
let response = parse_response(text, &candidates).unwrap();
assert_eq!(response.tool, "Grep");
assert_eq!(response.target, "fn main");
}
// Tool typo, two field-name typos, and a trailing comma in a single response.
#[test]
fn test_parse_response_combined_repairs() {
let candidates = vec!["Read".to_string(), "Write".to_string()];
let text = r#"{"tool": "Raed", "taget": "src/lib.rs", "confindence": 0.75,}"#;
let response = parse_response(text, &candidates).unwrap();
assert_eq!(response.tool, "Read");
assert_eq!(response.target, "src/lib.rs");
assert!((response.confidence - 0.75).abs() < 0.01);
}
}