use regex::Regex;
use serde::Serialize;
use std::sync::LazyLock;
/// Retry budget for normalization failures (value: 2).
///
/// NOTE(review): not referenced in this section of the file; presumably callers use it to
/// cap how many corrective prompts are sent after `detect_normalization_failure` reports a
/// problem — confirm against the caller.
pub const MAX_NORMALIZATION_RETRIES: u8 = 2;
/// Kinds of tool-calling format failures reported by `detect_normalization_failure`.
///
/// Variant names and order are part of the `Serialize` output; do not reorder casually.
#[derive(Debug, Clone, PartialEq, Eq, Serialize)]
pub enum NormalizationPattern {
    /// The response contains an `Action:` label or the text `tool_call`, but its brace
    /// structure looks broken (as judged by `has_broken_json`).
    MalformedToolCall,
    /// The response describes using a tool (e.g. "I would use web_search ...") without
    /// containing an `Action:` label.
    NarratedToolUse,
    /// The response is blank, or ends with an `Action:` label that names nothing.
    EmptyAction,
}
/// Matches text that looks like an attempted tool call: an `Action:` label (whitespace is
/// allowed before the colon) or the literal `tool_call`, case-insensitively.
///
/// The leading `\b` stops words that merely end in "action" (e.g. "transaction:",
/// "reaction:") from being mistaken for a tool call. Without it, ordinary prose would be
/// routed to the brace check and reported as a malformed tool call.
static RE_ACTION_OR_TOOL_CALL: LazyLock<Regex> = LazyLock::new(|| {
    Regex::new(r"(?i)(\bAction\s*:|tool_call)")
        .expect("RE_ACTION_OR_TOOL_CALL is a valid constant pattern")
});
/// Matches phrases in which the model narrates a tool invocation instead of performing it
/// (e.g. "I would use web_search"), case-insensitively. The phrase must be followed by
/// whitespace and at least one word character.
///
/// `\b` prevents matches inside longer words (e.g. "AI would use ..."). `['’]` accepts
/// both the ASCII and the typographic apostrophe in "I'll run".
static RE_NARRATED: LazyLock<Regex> = LazyLock::new(|| {
    Regex::new(
        r"(?i)\b(I would use|I['’]ll run|let me use the|I should call|I need to invoke)\s+\w+",
    )
    .expect("RE_NARRATED is a valid constant pattern")
});
/// Matches an `Action:` label (case-insensitive) followed only by whitespace up to the end
/// of the whole text.
///
/// The `m` flag is not set, so `$` anchors at the end of the input: this fires only when a
/// bare label is the last thing in the response. `\b` keeps text that merely ends in a word
/// like "transaction:" from matching.
static RE_EMPTY_ACTION: LazyLock<Regex> = LazyLock::new(|| {
    Regex::new(r"(?i)\bAction\s*:\s*$").expect("RE_EMPTY_ACTION is a valid constant pattern")
});
/// Classifies a model response that failed to follow the tool-calling format.
///
/// Checks run in a fixed order and the first hit wins:
/// 1. Whitespace-only text, or text ending in a bare `Action:` label, yields
///    [`NormalizationPattern::EmptyAction`].
/// 2. Otherwise, if the text contains an `Action:` label or `tool_call`, the result is
///    [`NormalizationPattern::MalformedToolCall`] when the brace structure looks broken
///    (see `has_broken_json`), and `None` when it looks well-formed. Narration is not
///    checked in this case.
/// 3. Otherwise, narrated phrases such as "I would use <tool>" yield
///    [`NormalizationPattern::NarratedToolUse`].
///
/// Returns `None` when none of the above apply (e.g. an ordinary text answer).
pub fn detect_normalization_failure(content: &str) -> Option<NormalizationPattern> {
    // `||` short-circuits, so the cheap blank check still runs before the regex.
    let nothing_to_act_on = content.trim().is_empty() || RE_EMPTY_ACTION.is_match(content);
    if nothing_to_act_on {
        return Some(NormalizationPattern::EmptyAction);
    }

    let attempted_tool_call = RE_ACTION_OR_TOOL_CALL.is_match(content);
    if attempted_tool_call {
        // An attempted call is either malformed or fine; never "narrated".
        return has_broken_json(content).then_some(NormalizationPattern::MalformedToolCall);
    }

    RE_NARRATED
        .is_match(content)
        .then_some(NormalizationPattern::NarratedToolUse)
}
/// Builds the corrective message to send back to the model for a detected `pattern`.
///
/// `tool_count` is only used for [`NormalizationPattern::NarratedToolUse`], where it is
/// reported to the model. The noun agrees with the count ("1 tool", "7 tools"). The
/// other variants return fixed text.
pub fn build_normalization_retry_prompt(
    pattern: &NormalizationPattern,
    tool_count: usize,
) -> String {
    match pattern {
        NormalizationPattern::MalformedToolCall => "Your previous response contained a malformed tool call. \
             Please retry using the correct JSON format:\n\
             Action: tool_name\n\
             Action Input: {\"param\": \"value\"}"
            .to_string(),
        NormalizationPattern::NarratedToolUse => {
            // Avoid "You have 1 tools available." in the user-facing prompt.
            let noun = if tool_count == 1 { "tool" } else { "tools" };
            format!(
                "You described what you would do instead of doing it. \
                 Use the Action/Action Input format to actually invoke the tool. \
                 You have {tool_count} {noun} available."
            )
        }
        NormalizationPattern::EmptyAction => "Your previous response was empty. \
             Please provide either a direct answer or use a tool via \
             Action/Action Input format."
            .to_string(),
    }
}
/// Heuristically decides whether the JSON payload of a tool call is malformed.
///
/// Returns `true` when:
/// - `content` contains no `{` at all;
/// - a `}` appears with no matching `{` before it (so `"}{"` is rejected even though the
///   counts agree); or
/// - braces are still open at the end of the text.
///
/// Double-quoted strings (with backslash escapes) are skipped while inside a brace pair,
/// so values such as `{"query": "a { b"}` are not misreported. Quotes outside any brace
/// pair (e.g. in a preceding `Thought:` line) are treated as plain text.
///
/// This is a structural check, not a JSON parser: a balanced but otherwise invalid payload
/// such as `{query: rust}` returns `false`.
fn has_broken_json(content: &str) -> bool {
    let mut depth: usize = 0;
    let mut saw_open = false;
    let mut in_string = false;
    let mut escaped = false;

    for c in content.chars() {
        if in_string {
            if escaped {
                escaped = false;
            } else if c == '\\' {
                escaped = true;
            } else if c == '"' {
                in_string = false;
            }
            continue;
        }
        match c {
            '{' => {
                depth += 1;
                saw_open = true;
            }
            '}' => match depth.checked_sub(1) {
                Some(d) => depth = d,
                // A closing brace with nothing open can never be valid JSON.
                None => return true,
            },
            // Strings only start inside an object; prose quotes stay plain text.
            '"' if depth > 0 => in_string = true,
            _ => {}
        }
    }

    // An unterminated string always leaves depth > 0, so it is covered here too.
    !saw_open || depth != 0
}
#[cfg(test)]
mod tests {
    use super::NormalizationPattern::{EmptyAction, MalformedToolCall, NarratedToolUse};
    use super::*;

    /// Asserts that `content` is classified as `expected`, echoing the input on failure.
    fn assert_classified(content: &str, expected: Option<NormalizationPattern>) {
        let actual = detect_normalization_failure(content);
        assert_eq!(actual, expected, "input: {content:?}");
    }

    #[test]
    fn detects_malformed_tool_call_unbalanced_braces() {
        assert_classified(
            "Action: web_search\nAction Input: {\"query\": \"rust async\"",
            Some(MalformedToolCall),
        );
    }

    #[test]
    fn detects_malformed_tool_call_no_json() {
        assert_classified(
            "Action: web_search\nAction Input: query rust async",
            Some(MalformedToolCall),
        );
    }

    #[test]
    fn detects_narrated_tool_use() {
        assert_classified(
            "I would use web_search to find recent articles on the topic.",
            Some(NarratedToolUse),
        );
    }

    #[test]
    fn detects_empty_action_whitespace_only() {
        assert_classified(" \n\t ", Some(EmptyAction));
    }

    #[test]
    fn normal_tool_call_not_detected_as_failure() {
        assert_classified(
            "Action: web_search\nAction Input: {\"query\": \"rust async\"}",
            None,
        );
    }

    #[test]
    fn normal_text_response_not_detected_as_failure() {
        assert_classified(
            "The answer is 42. Rust is a systems programming language.",
            None,
        );
    }

    #[test]
    fn retry_prompt_includes_tool_count() {
        let prompt = build_normalization_retry_prompt(&NarratedToolUse, 7);
        assert!(prompt.contains("7 tools available"), "prompt: {prompt:?}");
    }

    #[test]
    fn detects_empty_action_line() {
        assert_classified(
            "Thought: I should search for this.\nAction: ",
            Some(EmptyAction),
        );
    }
}