use serde::{Deserialize, Serialize};
use std::collections::HashMap;
/// Action configured for when a step can no longer be retried.
///
/// `ProgressTracker` only stores and exposes this value (via `limit_action`);
/// carrying out the action is left to the caller.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub enum LimitAction {
    /// Skip the step. This is the `Default` variant; presumably the caller
    /// moves on to the next step — confirm against the caller.
    SkipStep,
    /// Abort the task. Interpretation is up to the caller.
    AbortTask,
    /// Escalate the failure. Interpretation is up to the caller.
    Escalate,
}
impl Default for LimitAction {
fn default() -> Self {
Self::SkipStep
}
}
/// Limits that decide when `ProgressTracker` stops retrying a step.
///
/// Every field has a serde default, so any subset of fields may be omitted
/// when deserializing. The field defaults are the same values
/// `StepIterationConfig::default()` produces.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct StepIterationConfig {
    /// Number of recorded attempts at which a step is stopped (default: 2).
    /// The tracker stops once `attempts >= max_reattempts_per_step`.
    #[serde(default = "default_max_reattempts")]
    pub max_reattempts_per_step: u32,
    /// Normalized Levenshtein similarity (`0.0..=1.0`) between two consecutive
    /// errors at or above which a step is treated as stuck (default: 0.85).
    #[serde(default = "default_similarity_threshold")]
    pub similarity_threshold: f64,
    /// Action exposed to callers through `ProgressTracker::limit_action`
    /// (default: `LimitAction::SkipStep`). The tracker does not act on it.
    #[serde(default)]
    pub on_limit_reached: LimitAction,
}
/// Default for `StepIterationConfig::max_reattempts_per_step`, shared by the
/// serde field default and the `Default` impl.
fn default_max_reattempts() -> u32 {
    const DEFAULT_MAX_REATTEMPTS: u32 = 2;
    DEFAULT_MAX_REATTEMPTS
}
/// Default for `StepIterationConfig::similarity_threshold`, shared by the
/// serde field default and the `Default` impl.
fn default_similarity_threshold() -> f64 {
    const DEFAULT_SIMILARITY_THRESHOLD: f64 = 0.85;
    DEFAULT_SIMILARITY_THRESHOLD
}
impl Default for StepIterationConfig {
fn default() -> Self {
Self {
max_reattempts_per_step: default_max_reattempts(),
similarity_threshold: default_similarity_threshold(),
on_limit_reached: LimitAction::default(),
}
}
}
/// Outcome of a retry check for a single step.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum StepDecision {
    /// The step has not hit any stop condition and may be attempted again.
    Continue,
    /// The step should not be retried; `reason` is a human-readable explanation.
    Stop { reason: String },
}
/// Bookkeeping that `ProgressTracker` keeps for one step.
#[derive(Debug)]
struct StepProgress {
    /// Attempts recorded since the step was last begun.
    attempts: u32,
    /// Error output of the most recently recorded attempt, if any.
    last_error: Option<String>,
}
/// Counts retry attempts per step and decides when retrying should stop:
/// either the attempt limit is reached, or two consecutive errors are too
/// similar (see `StepIterationConfig`).
pub struct ProgressTracker {
    /// Limits applied to every tracked step.
    config: StepIterationConfig,
    /// Progress keyed by step id.
    steps: HashMap<String, StepProgress>,
}
impl ProgressTracker {
pub fn new(config: StepIterationConfig) -> Self {
Self {
config,
steps: HashMap::new(),
}
}
pub fn begin_step(&mut self, step_id: impl Into<String>) {
let id = step_id.into();
self.steps.insert(
id,
StepProgress {
attempts: 0,
last_error: None,
},
);
}
pub fn record_attempt(&mut self, step_id: &str, error_output: impl Into<String>) {
let output = error_output.into();
if let Some(progress) = self.steps.get_mut(step_id) {
progress.attempts += 1;
progress.last_error = Some(output);
}
}
pub fn should_continue(&self, step_id: &str) -> StepDecision {
let progress = match self.steps.get(step_id) {
Some(p) => p,
None => return StepDecision::Continue,
};
if progress.attempts >= self.config.max_reattempts_per_step {
return StepDecision::Stop {
reason: format!(
"Step '{}' reached max reattempts ({})",
step_id, self.config.max_reattempts_per_step
),
};
}
StepDecision::Continue
}
pub fn record_and_check(
&mut self,
step_id: &str,
error_output: impl Into<String>,
) -> StepDecision {
let output = error_output.into();
if let Some(progress) = self.steps.get(step_id) {
if let Some(ref last) = progress.last_error {
let similarity = normalized_levenshtein(last, &output);
if similarity >= self.config.similarity_threshold {
let attempts = progress.attempts + 1;
if let Some(p) = self.steps.get_mut(step_id) {
p.attempts = attempts;
p.last_error = Some(output);
}
return StepDecision::Stop {
reason: format!(
"Step '{}' appears stuck: consecutive errors are {:.0}% similar (threshold: {:.0}%)",
step_id,
similarity * 100.0,
self.config.similarity_threshold * 100.0
),
};
}
}
}
self.record_attempt(step_id, &output);
self.should_continue(step_id)
}
pub fn attempt_count(&self, step_id: &str) -> u32 {
self.steps.get(step_id).map_or(0, |p| p.attempts)
}
pub fn reset_step(&mut self, step_id: &str) {
self.steps.remove(step_id);
}
pub fn limit_action(&self) -> &LimitAction {
&self.config.on_limit_reached
}
}
/// Similarity of `a` and `b` in `0.0..=1.0`, computed as
/// `1 - levenshtein(a, b) / max(len(a), len(b))`. Lengths and edits are
/// counted in `char`s (Unicode scalar values), not bytes.
///
/// Two empty strings score `1.0`; an empty string against a non-empty one
/// scores `0.0`.
fn normalized_levenshtein(a: &str, b: &str) -> f64 {
    // Identical inputs (e.g. a repeated error) need no DP table at all.
    if a == b {
        return 1.0;
    }
    // Decode each string once instead of re-walking UTF-8 on every DP row.
    let a_chars: Vec<char> = a.chars().collect();
    let b_chars: Vec<char> = b.chars().collect();
    let max_len = a_chars.len().max(b_chars.len());
    if a_chars.is_empty() || b_chars.is_empty() {
        return 0.0;
    }

    // A shared prefix or suffix never contributes to the edit distance.
    // Trimming it shrinks the O(n * m) table, which matters for errors that
    // differ only in a small region such as a timestamp.
    let prefix = a_chars
        .iter()
        .zip(&b_chars)
        .take_while(|(x, y)| x == y)
        .count();
    let (a_rest, b_rest) = (&a_chars[prefix..], &b_chars[prefix..]);
    let suffix = a_rest
        .iter()
        .rev()
        .zip(b_rest.iter().rev())
        .take_while(|(x, y)| x == y)
        .count();
    let a_core = &a_rest[..a_rest.len() - suffix];
    let b_core = &b_rest[..b_rest.len() - suffix];

    // Size the rolling DP rows by the shorter side to bound memory.
    let (short, long) = if a_core.len() <= b_core.len() {
        (a_core, b_core)
    } else {
        (b_core, a_core)
    };
    let mut prev_row: Vec<usize> = (0..=short.len()).collect();
    let mut curr_row = vec![0usize; short.len() + 1];
    for (i, &long_ch) in long.iter().enumerate() {
        curr_row[0] = i + 1;
        for (j, &short_ch) in short.iter().enumerate() {
            let substitution = prev_row[j] + usize::from(long_ch != short_ch);
            curr_row[j + 1] = substitution
                .min(prev_row[j + 1] + 1)
                .min(curr_row[j] + 1);
        }
        std::mem::swap(&mut prev_row, &mut curr_row);
    }
    let distance = prev_row[short.len()];
    1.0 - (distance as f64 / max_len as f64)
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_levenshtein_identical() {
        assert!((normalized_levenshtein("hello", "hello") - 1.0).abs() < f64::EPSILON);
    }

    #[test]
    fn test_levenshtein_empty_both() {
        assert!((normalized_levenshtein("", "") - 1.0).abs() < f64::EPSILON);
    }

    #[test]
    fn test_levenshtein_one_empty() {
        assert!((normalized_levenshtein("", "hello") - 0.0).abs() < f64::EPSILON);
        assert!((normalized_levenshtein("hello", "") - 0.0).abs() < f64::EPSILON);
    }

    #[test]
    fn test_levenshtein_similar() {
        let sim = normalized_levenshtein("kitten", "sitting");
        assert!(sim > 0.5 && sim < 0.6);
    }

    #[test]
    fn test_levenshtein_completely_different() {
        let sim = normalized_levenshtein("abc", "xyz");
        assert!(sim < 0.1);
    }

    #[test]
    fn test_levenshtein_symmetric() {
        let forward = normalized_levenshtein("flaw", "lawn");
        let backward = normalized_levenshtein("lawn", "flaw");
        assert!((forward - backward).abs() < f64::EPSILON);
        assert!((forward - 0.5).abs() < f64::EPSILON);
    }

    #[test]
    fn test_levenshtein_counts_chars_not_bytes() {
        // "é" is two bytes in UTF-8 but a single edit.
        let sim = normalized_levenshtein("café", "cafe");
        assert!((sim - 0.75).abs() < f64::EPSILON);
    }

    #[test]
    fn test_continue_on_first_attempt() {
        let mut tracker = ProgressTracker::new(StepIterationConfig::default());
        tracker.begin_step("step1");
        assert!(matches!(
            tracker.should_continue("step1"),
            StepDecision::Continue
        ));
    }

    #[test]
    fn test_unknown_step_continues() {
        let tracker = ProgressTracker::new(StepIterationConfig::default());
        assert!(matches!(
            tracker.should_continue("missing"),
            StepDecision::Continue
        ));
        assert_eq!(tracker.attempt_count("missing"), 0);
    }

    #[test]
    fn test_stop_at_max_attempts() {
        let mut tracker = ProgressTracker::new(StepIterationConfig {
            max_reattempts_per_step: 2,
            ..Default::default()
        });
        tracker.begin_step("step1");
        tracker.record_attempt("step1", "error A");
        assert!(matches!(
            tracker.should_continue("step1"),
            StepDecision::Continue
        ));
        tracker.record_attempt("step1", "error B");
        assert!(matches!(
            tracker.should_continue("step1"),
            StepDecision::Stop { .. }
        ));
    }

    #[test]
    fn test_record_and_check_stops_at_limit_with_distinct_errors() {
        let mut tracker = ProgressTracker::new(StepIterationConfig {
            max_reattempts_per_step: 2,
            similarity_threshold: 0.85,
            ..Default::default()
        });
        tracker.begin_step("step1");
        let first = tracker.record_and_check("step1", "connection timeout to api.example.com");
        assert!(matches!(first, StepDecision::Continue));
        let second = tracker.record_and_check("step1", "permission denied for /etc/secret");
        match second {
            StepDecision::Stop { reason } => assert!(reason.contains("max reattempts")),
            StepDecision::Continue => panic!("expected the attempt limit to stop the step"),
        }
        assert_eq!(tracker.attempt_count("step1"), 2);
    }

    #[test]
    fn test_stop_on_similar_errors() {
        let mut tracker = ProgressTracker::new(StepIterationConfig {
            max_reattempts_per_step: 10,
            similarity_threshold: 0.85,
            ..Default::default()
        });
        tracker.begin_step("step1");
        tracker.record_attempt("step1", "connection timeout to api.example.com:443");
        let decision =
            tracker.record_and_check("step1", "connection timeout to api.example.com:443");
        assert!(matches!(decision, StepDecision::Stop { .. }));
    }

    #[test]
    fn test_similarity_stop_reason_and_attempt_count() {
        let mut tracker = ProgressTracker::new(StepIterationConfig {
            max_reattempts_per_step: 10,
            similarity_threshold: 0.85,
            ..Default::default()
        });
        tracker.begin_step("step1");
        tracker.record_attempt("step1", "connection timeout to api.example.com:443");
        let decision =
            tracker.record_and_check("step1", "connection timeout to api.example.com:443");
        match decision {
            StepDecision::Stop { reason } => assert!(reason.contains("appears stuck")),
            StepDecision::Continue => panic!("expected a similarity stop"),
        }
        // The similar attempt is still counted.
        assert_eq!(tracker.attempt_count("step1"), 2);
    }

    #[test]
    fn test_continue_on_different_errors() {
        let mut tracker = ProgressTracker::new(StepIterationConfig {
            max_reattempts_per_step: 10,
            similarity_threshold: 0.85,
            ..Default::default()
        });
        tracker.begin_step("step1");
        tracker.record_attempt("step1", "connection timeout to api.example.com");
        let decision = tracker.record_and_check("step1", "permission denied for /etc/secret");
        assert!(matches!(decision, StepDecision::Continue));
    }

    #[test]
    fn test_begin_step_resets() {
        let mut tracker = ProgressTracker::new(StepIterationConfig::default());
        tracker.begin_step("step1");
        tracker.record_attempt("step1", "error");
        assert_eq!(tracker.attempt_count("step1"), 1);
        tracker.begin_step("step1");
        assert_eq!(tracker.attempt_count("step1"), 0);
    }

    #[test]
    fn test_reset_step_removes() {
        let mut tracker = ProgressTracker::new(StepIterationConfig::default());
        tracker.begin_step("step1");
        tracker.record_attempt("step1", "error");
        tracker.reset_step("step1");
        assert_eq!(tracker.attempt_count("step1"), 0);
    }

    #[test]
    fn test_limit_action_exposes_config() {
        let tracker = ProgressTracker::new(StepIterationConfig {
            on_limit_reached: LimitAction::Escalate,
            ..Default::default()
        });
        assert!(matches!(tracker.limit_action(), LimitAction::Escalate));
    }

    #[test]
    fn test_default_config_values() {
        let config = StepIterationConfig::default();
        assert_eq!(config.max_reattempts_per_step, 2);
        assert!((config.similarity_threshold - 0.85).abs() < f64::EPSILON);
        assert!(matches!(config.on_limit_reached, LimitAction::SkipStep));
    }
}